diff --git a/create_new_architecture.py b/create_new_architecture.py new file mode 100755 index 0000000000..41b239665c --- /dev/null +++ b/create_new_architecture.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python3 +""" +Create a new independent architecture in metatrain. + +This script copies the vanilla PET architecture and creates a new architecture +with the given name, automatically updating all imports. + +Usage: + python create_new_architecture.py pet_1 +""" + +import argparse +import shutil +import sys +from pathlib import Path +import re + + +def update_imports_in_file(filepath: Path, old_arch: str, new_arch: str) -> None: + """Update imports from old_arch to new_arch in a Python file.""" + try: + content = filepath.read_text() + + # Replace imports + content = re.sub( + rf'from metatrain\.{re.escape(old_arch)}\.', + f'from metatrain.{new_arch}.', + content + ) + content = re.sub( + rf'from metatrain\.{re.escape(old_arch)} import', + f'from metatrain.{new_arch} import', + content + ) + content = re.sub( + rf'import metatrain\.{re.escape(old_arch)}\.', + f'import metatrain.{new_arch}.', + content + ) + + # Update architecture_name entries inside checkpoints + content = re.sub( + r'"architecture_name"\s*:\s*"pet"', + f'"architecture_name": "{new_arch}"', + content, + ) + content = re.sub( + r"'architecture_name'\s*:\s*'pet'", + f"'architecture_name': '{new_arch}'", + content, + ) + + filepath.write_text(content) + except Exception as e: + print(f"Warning: Could not update {filepath}: {e}") + + +def create_architecture(arch_name: str) -> None: + """Create a new architecture by copying and modifying vanilla PET.""" + + # Normalize path + metatrain_src = Path("./src").resolve() + + # Determine paths + source_arch = metatrain_src / "metatrain" / "pet" + + # Handle nested names like "experimental.my_arch" + arch_parts = arch_name.split(".") + target_arch = metatrain_src / "metatrain" / Path(*arch_parts) + + # ======================================================================== + # Validation + # ======================================================================== + + if not source_arch.exists(): + print(f"❌ Error: Cannot find source PET architecture at: {source_arch}") + sys.exit(1) + + if target_arch.exists(): + print(f"❌ Error: Architecture '{arch_name}' already exists at: {target_arch}") + sys.exit(1) + + # ======================================================================== + # Create directory structure + # ======================================================================== + + print(f"Creating architecture: {arch_name}") + print(f"Source: {source_arch}") + print(f"Target: {target_arch}") + print() + + (target_arch / "modules").mkdir(parents=True, exist_ok=True) + (target_arch / "tests").mkdir(parents=True, exist_ok=True) + + print("✓ Created directories") + + # ======================================================================== + # Copy files + # ======================================================================== + + # Copy main files + for filename in ["__init__.py", "model.py", "trainer.py", "documentation.py"]: + src = source_arch / filename + dst = target_arch / filename + if src.exists(): + shutil.copy2(src, dst) + + print("✓ Copied main files (__init__.py, model.py, trainer.py, documentation.py)") + + # Copy modules + modules_src = source_arch / "modules" + if modules_src.exists(): + for item in modules_src.iterdir(): + if item.is_file(): + shutil.copy2(item, target_arch / "modules" / item.name) + elif item.is_dir(): + shutil.copytree(item, target_arch / 
"modules" / item.name, dirs_exist_ok=True) + + print("✓ Copied modules directory") + + # Copy optional files + if (source_arch / "checkpoints.py").exists(): + shutil.copy2(source_arch / "checkpoints.py", target_arch / "checkpoints.py") + print("✓ Copied checkpoints.py") + + # if (source_arch / "tests").exists(): + # for item in (source_arch / "tests").iterdir(): + # if item.is_file(): + # shutil.copy2(item, target_arch / "tests" / item.name) + # elif item.is_dir(): + # shutil.copytree(item, target_arch / "tests" / item.name, dirs_exist_ok=True) + # print("✓ Copied tests directory") + + # ======================================================================== + # Update imports + # ======================================================================== + + print() + print(f"Updating imports from 'metatrain.pet' to 'metatrain.{arch_name}'...") + + for python_file in target_arch.rglob("*.py"): + update_imports_in_file(python_file, "pet", arch_name) + + print("✓ Updated imports in all Python files") + + # ======================================================================== + # Summary + # ======================================================================== + + print() + print(f"✅ Successfully created architecture: {arch_name}") + print() + print(f"Location: {target_arch}") + print() + print("Next steps:") + print(f"1. Modify the model in: {target_arch / 'model.py'}") + print(f"2. Customize hyperparameters in: {target_arch / 'documentation.py'}") + print("3. Test the architecture:") + print(f" python -c \"from metatrain.utils.architectures import import_architecture; arch = import_architecture('{arch_name}'); print('Architecture loaded successfully!')\"") + print() + print("4. Use in training with options.yaml:") + print(" architecture:") + print(f" name: {arch_name}") + print(" model:") + print(" cutoff: 5.0") + print(" # ... 
other hyperparameters") + print() + + +def main(): + parser = argparse.ArgumentParser( + description="Create a new independent architecture in metatrain", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python create_new_architecture.py pet_1 + python create_new_architecture.py pet_2 + python create_new_architecture.py experimental.my_arch + python create_new_architecture.py pet_custom /path/to/metatrain/src + """, + ) + + parser.add_argument( + "name", + help="Name of the new architecture (e.g., pet_1, experimental.my_arch)", + ) + + args = parser.parse_args() + + try: + create_architecture(args.name) + except Exception as e: + print(f"❌ Error: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/pet-mods.txt b/pet-mods.txt new file mode 100644 index 0000000000..9f1f226d97 --- /dev/null +++ b/pet-mods.txt @@ -0,0 +1,16 @@ +Relative to un-modified PET: + +pet_1 -> biases on covariant readout layers removed +pet_2 -> Filippo's SoH embedding, with Lmax=10 +pet_3 -> shift LLFs before readout layers +pet_4 -> Filippo's SoH embedding, shift LLFs before readout layers +pet_5 -> SpH embedding (simple switch SoH -> SpH) +pet_6 -> SpH and radial polynomial embeddings +pet_7 -> SpH and explicit Bessel radial basis embedding +pet_8 -> SoH and explicit Bessel radial basis embedding +pet_9 -> Filippo's SoH embedding, simple MoE on linear readout layers +pet_10 -> all biases removed from model +pet_11 -> Filippo's SoH embedding, with Lmax=8 +pet_12 -> Filippo's SoH embedding, with Lmax=4 +pet_13 -> Filippo's SoH embedding, with Lmax=2 +pet_14 -> Filippo's SoH embedding, all biases removed \ No newline at end of file diff --git a/src/metatrain/pet/model.py b/src/metatrain/pet/model.py index 757e778140..75abd311a6 100644 --- a/src/metatrain/pet/model.py +++ b/src/metatrain/pet/model.py @@ -16,6 +16,7 @@ NeighborListOptions, System, ) +from torch.utils.hooks import RemovableHandle from metatrain.utils.abc import ModelInterface from metatrain.utils.additive import ZBL, CompositionModel @@ -29,7 +30,7 @@ from . 
import checkpoints from .documentation import ModelHypers from .modules.finetuning import apply_finetuning_strategy -from .modules.structures import systems_to_batch +from .modules.structures import get_pair_sample_labels, systems_to_batch from .modules.transformer import CartesianTransformer @@ -405,27 +406,44 @@ def forward( if self.single_label.values.device != device: self._move_labels_to_device(device) - with torch.profiler.record_function("PET::systems_to_batch"): - # **Stage 0: Input Preparation** - ( - element_indices_nodes, - element_indices_neighbors, - edge_vectors, - edge_distances, - padding_mask, - reverse_neighbor_index, - cutoff_factors, - system_indices, + # **Stage 0: Input Preparation** + ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + padding_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + sample_labels, + centers, + nef_to_edges_neighbor, + ) = systems_to_batch( + systems, + nl_options, + self.atomic_types, + self.species_to_species_index, + self.cutoff_function, + self.cutoff_width, + self.num_neighbors_adaptive, + selected_atoms, + ) + + pair_sample_labels = get_pair_sample_labels( + systems, sample_labels, nl_options, device + ) + + # Optional diagnostic token capture: register temporary module hooks + diagnostic_handles = torch.jit.annotate(List[Any], []) + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + diagnostic_handles = self._prepare_diagnostic_handles( + outputs, + return_dict, + centers, + nef_to_edges_neighbor, sample_labels, - ) = systems_to_batch( - systems, - nl_options, - self.atomic_types, - self.species_to_species_index, - self.cutoff_function, - self.cutoff_width, - self.num_neighbors_adaptive, - selected_atoms, + pair_sample_labels, ) # the scaled_dot_product_attention function from torch cannot do @@ -443,6 +461,17 @@ def forward( padding_mask=padding_mask, cutoff_factors=cutoff_factors, ) + for featurizer_input_name, tensor in featurizer_inputs.items(): + if "mtt::features::" + featurizer_input_name in outputs: + return_dict["mtt::features::" + featurizer_input_name] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) node_features_list, edge_features_list = self._calculate_features( featurizer_inputs, use_manual_attention=use_manual_attention, @@ -568,8 +597,202 @@ def forward( return_dict[name].keys, output_blocks ) + # remove any diagnostic hooks we registered and attach tokens + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + for h in diagnostic_handles: + h.remove() + return return_dict + def _create_diagnostic_feature_tensormap( + self, + tensor: torch.Tensor, + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> TensorMap: + assert tensor.shape[0] == sample_labels.values.shape[0], ( + "diagnostic feature tensor must be per-atom or per-pair like in shape." 
+ f" Got tensor.shape = {tensor.shape}, " + ) + + device = tensor.device + + outp = tensor.detach().clone() + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + if outp.shape[1] == 1: # node-like, shape (n_atoms, 1, d) + outp = outp.squeeze(1) + labels = sample_labels + + else: # edge-like, shape (n_atoms, num_neighbors, d) + outp = outp[centers, nef_to_edges_neighbor] + labels = pair_sample_labels + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + return TensorMap( + Labels(["_"], torch.tensor([[0]])).to(device=device), + [ + TensorBlock( + values=outp, + samples=labels.to(device=device), + components=[], + properties=Labels( + ["_"], + torch.arange(outp.shape[1]).reshape(-1, 1), + ).to(device=device), + ) + ], + ) + + @torch.jit.ignore + def _prepare_diagnostic_handles( + self, + outputs: Dict[str, ModelOutput], + return_dict: Dict[str, Any], + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> List[Any]: + """ + Prepare forward hooks to capture diagnostic tokens from internal modules. + + :param outputs: Dictionary of requested outputs. + :param return_dict: Dictionary to store captured tokens. + :param centers: Tensor mapping center atoms to their indices. + :param nef_to_edges_neighbor: Tensor mapping neighbor edges to their indices. + :param sample_labels: Labels for individual atoms. + :param pair_sample_labels: Labels for atom pairs. + + :return: List of removable handles for the registered hooks. + """ + diagnostic_handles = [] + + def _resolve_module(path: str) -> Any: + obj: Any = self + for part in path.split("."): + if part.isdigit(): + obj = obj[int(part)] + else: + if not hasattr(obj, part): + raise AttributeError( + f"Module path '{path}' not found at '{part}'" + ) + obj = getattr(obj, part) + return obj + + # Build list of possible module paths that can be captured for these model + # hypers. 
+ possible_capture_paths: List[str] = [] + + for i in range(self.num_readout_layers): + possible_capture_paths.append(f"node_embedders.{i}") + possible_capture_paths.append("edge_embedder") + + for i in range(self.num_gnn_layers): + # Total GNN layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}_node") + possible_capture_paths.append(f"gnn_layers.{i}_edge") + + # Finer-grained: embeddings and compressions + possible_capture_paths.append(f"gnn_layers.{i}.edge_embedder") + if i > 0: + possible_capture_paths.append(f"gnn_layers.{i}.neighbor_embedder") + possible_capture_paths.append(f"gnn_layers.{i}.compress") + + # transformer layers + for j in range(self.num_attention_layers): + # Transformer layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_node") + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_edge") + + # Finer-grained: attention and MLP submodules + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_contraction" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_mlp" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_mlp" + ) + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}.mlp") + + # TODO: this is subtle, depending on whether the featurizer is feedforward or + # residual. To fix later. + # possible_capture_paths.append(f"combination_norms") + # possible_capture_paths.append(f"combination_mlps") + + for path in possible_capture_paths: + if "mtt::features::" + path not in outputs: + continue + + if "_node" in path: + suffix = "_node" + path = path.replace(suffix, "") + elif "_edge" in path: + suffix = "_edge" + path = path.replace(suffix, "") + else: + suffix = "" + + module = _resolve_module(path) + + def make_hook(p: str, suffix: str) -> Any: + def _hook( + module: torch.nn.Module, inp: Any, outp: Any + ) -> None: + if isinstance(outp, tuple): + assert "_node" in suffix or "_edge" in suffix, ( + "When capturing from a module that returns multiple " + "outputs, the requested output must carry the suffix " + "'_node' or '_edge'." 
+ ) + if suffix == "_node": + tensor = outp[0] + else: + assert suffix == "_edge" + tensor = outp[1] + + return_dict[f"mtt::features::{p}{suffix}"] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + else: + return_dict["mtt::features::" + p] = ( + self._create_diagnostic_feature_tensormap( + outp, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + + return _hook + + handle = module.register_forward_hook(make_hook(path, suffix)) + diagnostic_handles.append(handle) + + return diagnostic_handles + def _calculate_features( self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: diff --git a/src/metatrain/pet/modules/structures.py b/src/metatrain/pet/modules/structures.py index 8438a7c67e..f7e67f9eaf 100644 --- a/src/metatrain/pet/modules/structures.py +++ b/src/metatrain/pet/modules/structures.py @@ -155,6 +155,8 @@ def systems_to_batch( torch.Tensor, torch.Tensor, Labels, + torch.Tensor, + torch.Tensor, ]: """ Converts a list of systems to a batch required for the PET model. @@ -321,4 +323,55 @@ def systems_to_batch( cutoff_factors, system_indices, sample_labels, + centers, + nef_to_edges_neighbor, ) + + +def get_pair_sample_labels( + systems: List[System], + sample_labels: Labels, + nl_options: NeighborListOptions, + device: torch.device, +) -> Labels: + """ + Builds the pair samples labels for the input ``systems``, based on the pre-computed + neighbor list. These are 'off-site', i.e. not including self-interactions. + + :param systems: List of systems to build the pair sample labels for. + :param sample_labels: The sample labels for per-atom quantities. + :param nl_options: The neighbor list options to use for building the offsite labels. + :param device: The device to put the labels on. + :return: A dictionary with the pair sample labels for the onsite and offsite blocks. 
+ """ + sample_names = [ + "system", + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ] + + pair_sample_values = [] + for system_idx, system in enumerate(systems): + neighbor_list = system.get_neighbor_list(nl_options) + nl_values = neighbor_list.samples.values + + pair_sample_values.append( + torch.hstack( + [ + torch.full( + (nl_values.shape[0], 1), + system_idx, + dtype=torch.int32, + device=device, + ), + nl_values, + ], + ) + ) + pair_sample_values = torch.vstack(pair_sample_values) + pair_sample_labels = Labels(sample_names, pair_sample_values).to(device=device) + + return pair_sample_labels diff --git a/src/metatrain/pet/trainer.py b/src/metatrain/pet/trainer.py index fae970836a..49e82212bc 100644 --- a/src/metatrain/pet/trainer.py +++ b/src/metatrain/pet/trainer.py @@ -357,13 +357,31 @@ def train( # Log the initial learning rate: logging.info(f"Base learning rate: {self.hypers['learning_rate']}") - start_epoch = 0 if self.epoch is None else self.epoch + 1 + start_epoch = -1 if self.epoch is None else self.epoch + epoch = start_epoch - # Train the model: + # Save the untrained model checkpoint: if self.best_metric is None: self.best_metric = float("inf") + logging.info("Saving untrained model checkpoint") + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_untrained.ckpt", + ) + + # Train the model: logging.info("Starting training") - epoch = start_epoch + start_epoch = start_epoch + 1 for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]): if is_distributed: diff --git a/src/metatrain/pet_1/__init__.py b/src/metatrain/pet_1/__init__.py new file mode 100644 index 0000000000..01db7d8aa9 --- /dev/null +++ b/src/metatrain/pet_1/__init__.py @@ -0,0 +1,21 @@ +from .model import PET +from .trainer import Trainer + + +__model__ = PET +__trainer__ = Trainer +__capabilities__ = { + "supported_devices": __model__.__supported_devices__, + "supported_dtypes": __model__.__supported_dtypes__, +} + +__authors__ = [ + ("Sergey Pozdnyakov ", "@spozdn"), + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] + +__maintainers__ = [ + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] diff --git a/src/metatrain/pet_1/checkpoints.py b/src/metatrain/pet_1/checkpoints.py new file mode 100644 index 0000000000..020b1e6590 --- /dev/null +++ b/src/metatrain/pet_1/checkpoints.py @@ -0,0 +1,422 @@ +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + + +########################### +# MODEL ################### +########################### + + +def model_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 checkpoint to v2. + + :param checkpoint: The checkpoint to update. 
+ """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "additive_models.0.model.type_to_index" not in state_dict: + state_dict["additive_models.0.model.type_to_index"] = state_dict.pop( + "additive_models.0.type_to_index" + ) + + +def model_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "train_hypers" in checkpoint: + finetune_config = checkpoint["train_hypers"].get("finetune", {}) + else: + finetune_config = {} + state_dict["finetune_config"] = finetune_config + + +def model_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["epoch"] = checkpoint.get("epoch") + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + if checkpoint["best_model_state_dict"] is not None: + checkpoint["best_model_state_dict"] = checkpoint.get("best_model_state_dict") + + +def model_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 checkpoint to v5. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["model_data"]["dataset_info"]._atomic_types = list( + checkpoint["model_data"]["dataset_info"]._atomic_types + ) + + +def model_update_v5_v6(checkpoint: dict) -> None: + """ + Update a v5 checkpoint to v6. + + :param checkpoint: The checkpoint to update. + """ + if not checkpoint["best_model_state_dict"]: + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + +def model_update_v6_v7(checkpoint: dict) -> None: + """ + Update model checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to be updated. + """ + # this update consists in changes in the scaler + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if ( + "scaler.scales" not in state_dict + and "scaler.dummy_buffer" in state_dict + and "scaler.model.type_to_index" in state_dict + ): + continue # already updated + old_scales_tensor = state_dict.pop("scaler.scales") + old_output_name_to_output_index = {} + for target_index, target_name in enumerate( + checkpoint["model_data"]["dataset_info"].targets.keys() + ): + old_output_name_to_output_index[target_name] = target_index + state_dict["scaler.dummy_buffer"] = torch.tensor( + [0.0], dtype=old_scales_tensor.dtype + ) + state_dict["scaler.model.type_to_index"] = state_dict[ + "additive_models.0.model.type_to_index" + ] + for target_name, target_info in checkpoint["model_data"][ + "dataset_info" + ].targets.items(): + layout = target_info.layout + if layout.sample_names == ["system"]: + samples = Labels(["atomic_type"], torch.tensor([[-1]])) + + elif layout.sample_names == ["system", "atom"]: + samples = Labels( + ["atomic_type"], + torch.arange( + len(checkpoint["model_data"]["dataset_info"].atomic_types) + ).reshape(-1, 1), + ) + else: + raise ValueError( # will never happen + "Unknown sample kind. Please contact the developers." 
+ ) + scales_tensormap = TensorMap( + keys=layout.keys, + blocks=[ + TensorBlock( + values=torch.full( # important when scale_targets=False + (len(samples), len(block.properties)), + old_scales_tensor[ + old_output_name_to_output_index[target_name] + ], + dtype=torch.float64, + ), + samples=samples, + components=[], + properties=block.properties, + ) + for block in layout.blocks() + ], + ) + state_dict[f"scaler.{target_name}_scaler_buffer"] = mts.save_buffer( + mts.make_contiguous(scales_tensormap) + ) + + +def model_update_v7_v8(checkpoint: dict) -> None: + """ + Update a v7 checkpoint to v8. + + :param checkpoint: The checkpoint to update. + """ + if "node_embedding.weight" in checkpoint["model_state_dict"]: + ############################################## + # **Updating the large-scale PET checkpoints** + ############################################## + + checkpoint["model_data"]["model_hypers"]["normalization"] = "RMSNorm" + checkpoint["model_data"]["model_hypers"]["activation"] = "SwiGLU" + checkpoint["model_data"]["model_hypers"]["d_node"] = ( + 4 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PreLN" + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "feedforward" + checkpoint["model_data"]["model_hypers"]["d_feedforward"] = ( + 2 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + if (state_dict := checkpoint.get("model_state_dict")) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + if "embedding." in k and "node" not in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if "node_embedding." in k: + k = k.replace("node_embedding.", "node_embedders.0.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." 
in k: + v = v[:-1, :] # removing the embedding for padding +1 type + if "combination_rmsnorms" in k: + k = k.replace("combination_rmsnorms", "combination_norms") + new_state_dict[k] = v + checkpoint["model_state_dict"] = new_state_dict + + # for the large-scale checkpoints, the model for evaluation is always + # taken to be the last + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + else: + ############################################### + # **Updating the standard old PET checkpoints** + ############################################### + + # Adding the option for choosing the normalization type + if "normalization" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["normalization"] = "LayerNorm" + # Adding the option for choosing the activation function + if "activation" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["activation"] = "SiLU" + # Setting the node features dimension to be the same as d_pet if not specified + if "d_node" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["d_node"] = checkpoint[ + "model_data" + ]["model_hypers"]["d_pet"] + # Setting the default transformer type to PostLN if not specified + if "transformer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PostLN" + if "featurizer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "residual" + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + # Replacing the nn.Sequential MLP with a custom FeedForward module + if ".mlp.0" in k: + k = k.replace(".mlp.0", ".mlp.w_in") + if ".mlp.3" in k: + k = k.replace(".mlp.3", ".mlp.w_out") + # Moving the node embedder to a top-level PET model attribute + if "embedding." in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".node_embedder." in k: + key_content = k.split(".") + k = ".".join( + ["node_embedders", key_content[1], key_content[-1]] + ) + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + new_state_dict[k] = v + checkpoint[key] = new_state_dict + + +def model_update_v8_v9(checkpoint: dict) -> None: + """ + Update a v8 checkpoint to v9. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "finetune_config" in state_dict: + if "inherit_heads" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["inherit_heads"] = {} + if "method" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["method"] = "full" + + +def model_update_v9_v10(checkpoint: dict) -> None: + """ + Update a v9 checkpoint to v10. + + :param checkpoint: The checkpoint to update. 
+ """ + # Adding the num_neighbors_adaptive hyperparameter if not present + if "num_neighbors_adaptive" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["num_neighbors_adaptive"] = None + if "cutoff_function" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["cutoff_function"] = "Cosine" + + +def model_update_v10_v11(checkpoint: dict) -> None: + """ + Update a v10 checkpoint to v11. + + :param checkpoint: The checkpoint to update. + """ + # Adding the attention_temperature hyperparameter if not present + if "attention_temperature" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["attention_temperature"] = 1.0 + + +########################### +# TRAINER ################# +########################### + + +def trainer_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 Trainer checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["scheduler_factor"] = checkpoint["train_hypers"].get( + "scheduler_factor", 0.5 + ) + + +def trainer_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 Trainer checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + +def trainer_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 Trainer checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + new_loss_hypers[target_name] = { + "type": old_loss_hypers["type"], + "weight": old_loss_hypers["weights"].get(target_name, 1.0), + "reduction": old_loss_hypers["reduction"], + "sliding_factor": old_loss_hypers.get("sliding_factor", None), + } + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 Trainer checkpoint to v5. + + :param checkpoint: The checkpoint to update. + """ + raise ValueError( + "In order to use this checkpoint, you need metatrain 2025.10 or earlier. " + "You can install it with `pip install metatrain==2025.10`." + ) + + +def trainer_update_v5_v6(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 5 to version 6. + + :param checkpoint: The checkpoint to update. + """ + # num_workers=0 means that the main process will do the data loading, which is + # equivalent to not setting it (this was the behavior before v6) + checkpoint["train_hypers"]["num_workers"] = 0 + + +def trainer_update_v6_v7(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["fixed_scaling_weights"] = {} + + +def trainer_update_v7_v8(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. 
+ """ + # remove all entries in the loss `sliding_factor` + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + # retain everything except sliding_factor for each target + new_loss_hypers[target_name] = { + k: v + for k, v in old_loss_hypers[target_name].items() + if k != "sliding_factor" + } + + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v8_v9(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. + """ + # Adding the empty finetune config if not present + if "finetune" not in checkpoint["train_hypers"]: + checkpoint["train_hypers"]["finetune"] = { + "read_from": None, + "method": "full", + "config": {}, + "inherit_heads": {}, + } + + +def trainer_update_v9_v10(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 9 to version 10. + + :param checkpoint: The checkpoint to update. + """ + # Ensuring that the finetune read_from is None if not specified + checkpoint["train_hypers"]["remove_composition_contribution"] = True + + +def trainer_update_v10_v11(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 10 to version 11. + + :param checkpoint: The checkpoint to update. + """ + # - Remove the ``remove_composition_contribution`` hyper. + # - Rename ``fixed_composition_weights`` to ``atomic_baseline``. + # - If ``remove_composition_contribution`` is False, set all atomic baselines + # to 0.0 for all targets. + use_atomic_baseline = checkpoint["train_hypers"].pop( + "remove_composition_contribution" + ) + atomic_baseline = checkpoint["train_hypers"].pop("fixed_composition_weights") + + if not use_atomic_baseline: + # Just set + dataset_info = checkpoint["model_data"]["dataset_info"] + atomic_baseline = {target_name: 0.0 for target_name in dataset_info.targets} + + checkpoint["train_hypers"]["atomic_baseline"] = atomic_baseline + + +def trainer_update_v11_v12(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 11 to version 12. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["batch_atom_bounds"] = [None, None] diff --git a/src/metatrain/pet_1/documentation.py b/src/metatrain/pet_1/documentation.py new file mode 100644 index 0000000000..a243e5d175 --- /dev/null +++ b/src/metatrain/pet_1/documentation.py @@ -0,0 +1,251 @@ +""" +PET +=== + +PET is a cleaner, more user-friendly reimplementation of the original +PET model :footcite:p:`pozdnyakov_smooth_2023`. It is designed for better +modularity and maintainability, while preseving compatibility with the original +PET implementation in ``metatrain``. It also adds new features like long-range +features, better fine-tuning implementation, a possibility to train on +arbitrarty targets, and a faster inference due to the ``fast attention``. + +{{SECTION_INSTALLATION}} + +{{SECTION_DEFAULT_HYPERS}} + +Tuning hyperparameters +---------------------- + +The default hyperparameters above will work well in most cases, but they +may not be optimal for your specific dataset. There is good number of +parameters to tune, both for the +:ref:`model ` and the +:ref:`trainer `. Since seeing them +for the first time might be overwhelming, here we provide a **list of the +parameters that are in general the most important** (in decreasing order +of importance): + +.. container:: mtt-hypers-remove-classname + + .. 
+
+Tuning hyperparameters
+----------------------
+
+The default hyperparameters above will work well in most cases, but they
+may not be optimal for your specific dataset. There is a good number of
+parameters to tune, both for the
+:ref:`model ` and the
+:ref:`trainer `. Since seeing them
+for the first time might be overwhelming, here we provide a **list of the
+parameters that are in general the most important** (in decreasing order
+of importance):
+
+.. container:: mtt-hypers-remove-classname
+
+    .. autoattribute:: {{model_hypers_path}}.cutoff
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.num_neighbors_adaptive
+        :no-index:
+
+    .. autoattribute:: {{trainer_hypers_path}}.learning_rate
+        :no-index:
+
+    .. autoattribute:: {{trainer_hypers_path}}.batch_size
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.d_pet
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.d_node
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.num_gnn_layers
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.num_attention_layers
+        :no-index:
+
+    .. autoattribute:: {{trainer_hypers_path}}.loss
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.long_range
+        :no-index:
+"""
+
+from typing import Literal, Optional
+
+from typing_extensions import TypedDict
+
+from metatrain.pet_1.modules.finetuning import FinetuneHypers, NoFinetuneHypers
+from metatrain.utils.additive import FixedCompositionWeights
+from metatrain.utils.hypers import init_with_defaults
+from metatrain.utils.long_range import LongRangeHypers
+from metatrain.utils.loss import LossSpecification
+from metatrain.utils.scaler import FixedScalerWeights
+
+
+class ModelHypers(TypedDict):
+    """Hyperparameters for the PET model."""
+
+    cutoff: float = 4.5
+    """Cutoff radius for neighbor search.
+
+    This should be set to a value after which most of the interactions
+    between atoms are expected to be negligible. A lower cutoff will lead
+    to faster models.
+    """
+    num_neighbors_adaptive: Optional[int] = None
+    """Target number of neighbors for the adaptive cutoff scheme.
+
+    This parameter activates the adaptive cutoff functionality.
+    Each atomic environment has a different cutoff, chosen
+    such that the number of neighbors is approximately equal to this
+    value. This can be useful to obtain a more uniform number of neighbors
+    per atom, especially in sparse systems. Setting it to None disables
+    this feature and uses all neighbors within the fixed cutoff radius.
+    """
+    cutoff_function: Literal["Cosine", "Bump"] = "Bump"
+    """Type of the smoothing function at the cutoff"""
+    cutoff_width: float = 0.5
+    """Width of the smoothing function at the cutoff"""
+    d_pet: int = 128
+    """Dimension of the edge features.
+
+    This hyperparameter controls the width of the neural network. In general,
+    increasing it might lead to better accuracy, especially on larger datasets, at the
+    cost of increased training and evaluation time.
+    """
+    d_head: int = 128
+    """Dimension of the attention heads."""
+    d_node: int = 256
+    """Dimension of the node features.
+
+    Increasing this hyperparameter might lead to better accuracy,
+    with a relatively small increase in inference time.
+    """
+    d_feedforward: int = 256
+    """Dimension of the feedforward network in the attention layer."""
+    num_heads: int = 8
+    """Number of attention heads per attention layer."""
+    num_attention_layers: int = 2
+    """The number of attention layers in each layer of the graph
+    neural network. Depending on the dataset, increasing this hyperparameter might
+    lead to better accuracy, at the cost of increased training and evaluation time.
+    """
+    num_gnn_layers: int = 2
+    """The number of graph neural network layers.
+
+    In general, decreasing this hyperparameter to 1 will lead to much faster models,
+    at the expense of accuracy. Increasing it may or may not lead to better accuracy,
+    depending on the dataset, at the cost of increased training and evaluation time.
+ """ + normalization: Literal["RMSNorm", "LayerNorm"] = "RMSNorm" + """Layer normalization type.""" + activation: Literal["SiLU", "SwiGLU"] = "SwiGLU" + """Activation function.""" + attention_temperature: float = 1.0 + """The temperature scaling factor for attention scores.""" + transformer_type: Literal["PreLN", "PostLN"] = "PreLN" + """The order in which the layer normalization and attention + are applied in a transformer block. Available options are ``PreLN`` + (normalization before attention) and ``PostLN`` (normalization after attention).""" + featurizer_type: Literal["residual", "feedforward"] = "feedforward" + """Implementation of the featurizer of the model to use. Available + options are ``residual`` (the original featurizer from the PET paper, that uses + residual connections at each GNN layer for readout) and ``feedforward`` (a modern + version that uses the last representation after all GNN iterations for readout). + Additionally, the feedforward version uses bidirectional features flow during the + message passing iterations, that favors features flowing from atom ``i`` to atom + ``j`` to be not equal to the features flowing from atom ``j`` to atom ``i``.""" + zbl: bool = False + """Use ZBL potential for short-range repulsion""" + long_range: LongRangeHypers = init_with_defaults(LongRangeHypers) + """Long-range Coulomb interactions parameters.""" + + +class TrainerHypers(TypedDict): + """Hyperparameters for training PET models.""" + + distributed: bool = False + """Whether to use distributed training""" + distributed_port: int = 39591 + """Port for distributed communication among processes""" + batch_size: int = 16 + """The number of samples to use in each batch of training. This + hyperparameter controls the tradeoff between training speed and memory usage. In + general, larger batch sizes will lead to faster training, but might require more + memory.""" + num_epochs: int = 1000 + """Number of epochs.""" + warmup_fraction: float = 0.01 + """Fraction of training steps used for learning rate warmup.""" + learning_rate: float = 1e-4 + """Learning rate.""" + weight_decay: Optional[float] = None + + log_interval: int = 1 + """Interval to log metrics.""" + checkpoint_interval: int = 100 + """Interval to save checkpoints.""" + atomic_baseline: FixedCompositionWeights = {} + """The baselines for each target. + + By default, ``metatrain`` will fit a linear model (:class:`CompositionModel + `) to compute the + least squares baseline for each atomic species for each target. + + However, this hyperparameter allows you to provide your own baselines. + The value of the hyperparameter should be a dictionary where the keys are the + target names, and the values are either (1) a single baseline to be used for + all atomic types, or (2) a dictionary mapping atomic types to their baselines. + For example: + + - ``atomic_baseline: {"energy": {1: -0.5, 6: -10.0}}`` will fix the energy + baseline for hydrogen (Z=1) to -0.5 and for carbon (Z=6) to -10.0, while + fitting the baselines for the energy of all other atomic types, as well + as fitting the baselines for all other targets. + - ``atomic_baseline: {"energy": -5.0}`` will fix the energy baseline for + all atomic types to -5.0. + - ``atomic_baseline: {"mtt:dos": 0.0}`` sets the baseline for the "mtt:dos" + target to 0.0, effectively disabling the atomic baseline for that target. 
+ + This atomic baseline is substracted from the targets during training, which + avoids the main model needing to learn atomic contributions, and likely makes + training easier. When the model is used in evaluation mode, the atomic baseline + is added on top of the model predictions automatically. + + .. note:: + + This atomic baseline is a per-atom contribution. Therefore, if the property + you are predicting is a sum over all atoms (e.g., total energy), the + contribution of the atomic baseline to the total property will be the + atomic baseline multiplied by the number of atoms of that type in the + structure. + """ + scale_targets: bool = True + """Normalize targets to unit std during training.""" + fixed_scaling_weights: FixedScalerWeights = {} + """Weights for target scaling. + + This is passed to the ``fixed_weights`` argument of + :meth:`Scaler.train_model `, + see its documentation to understand exactly what to pass here. + """ + per_structure_targets: list[str] = [] + """Targets to calculate per-structure losses.""" + num_workers: Optional[int] = None + """Number of workers for data loading. If not provided, it is set + automatically.""" + log_mae: bool = True + """Log MAE alongside RMSE""" + log_separate_blocks: bool = False + """Log per-block error.""" + best_model_metric: Literal["rmse_prod", "mae_prod", "loss"] = "mae_prod" + """Metric used to select best checkpoint (e.g., ``rmse_prod``)""" + grad_clip_norm: float = 1.0 + """Maximum gradient norm value.""" + loss: str | dict[str, LossSpecification | str] = "mse" + """This section describes the loss function to be used. See the + :ref:`loss-functions` for more details.""" + batch_atom_bounds: list[Optional[int]] = [None, None] + """Bounds for the number of atoms per batch as [min, max]. Batches with atom + counts outside these bounds will be skipped during training. Use ``None`` for + either value to disable that bound. This is useful for preventing out-of-memory + errors and ensuring consistent computational load. Default: ``[None, None]``.""" + + finetune: NoFinetuneHypers | FinetuneHypers = { + "read_from": None, + "method": "full", + "config": {}, + "inherit_heads": {}, + } + """Parameters for fine-tuning trained PET models. + + See :ref:`label_fine_tuning_concept` for more details. + """ diff --git a/src/metatrain/pet_1/model.py b/src/metatrain/pet_1/model.py new file mode 100644 index 0000000000..bdd8a9c0a2 --- /dev/null +++ b/src/metatrain/pet_1/model.py @@ -0,0 +1,1646 @@ +import logging +import typing +import warnings +from math import prod +from typing import Any, Dict, List, Literal, Optional, Tuple + +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.operations._add import _add_block_block +from metatomic.torch import ( + AtomisticModel, + ModelCapabilities, + ModelMetadata, + ModelOutput, + NeighborListOptions, + System, +) +from torch.utils.hooks import RemovableHandle + +from metatrain.utils.abc import ModelInterface +from metatrain.utils.additive import ZBL, CompositionModel +from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.dtype import dtype_to_str +from metatrain.utils.long_range import DummyLongRangeFeaturizer, LongRangeFeaturizer +from metatrain.utils.metadata import merge_metadata +from metatrain.utils.scaler import Scaler +from metatrain.utils.sum_over_atoms import sum_over_atoms + +from . 
import checkpoints +from .documentation import ModelHypers +from .modules.finetuning import apply_finetuning_strategy +from .modules.structures import get_pair_sample_labels, systems_to_batch +from .modules.transformer import CartesianTransformer + + +AVAILABLE_FEATURIZERS = typing.get_args(ModelHypers.__annotations__["featurizer_type"]) + + +class PET(ModelInterface[ModelHypers]): + """ + Metatrain-native implementation of the PET architecture. + + Originally proposed in work (https://arxiv.org/abs/2305.19302v3), + and published in the `pet` package (https://github.com/spozdn/pet). + + :param hypers: Hyperparameters for the PET model. See the documentation for details. + :param dataset_info: Information about the dataset, including atomic types and + targets. + """ + + __checkpoint_version__ = 11 + __supported_devices__ = ["cuda", "cpu"] + __supported_dtypes__ = [torch.float32, torch.float64] + __default_metadata__ = ModelMetadata( + references={"architecture": ["https://arxiv.org/abs/2305.19302v3"]} + ) + component_labels: Dict[str, List[List[Labels]]] + NUM_FEATURE_TYPES: int = 2 # node + edge features + + def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None: + super().__init__(hypers, dataset_info, self.__default_metadata__) + + # Cache frequently accessed hyperparameters + self.cutoff = float(self.hypers["cutoff"]) + self.cutoff_function = self.hypers["cutoff_function"] + self.cutoff_width = float(self.hypers["cutoff_width"]) + self.num_neighbors_adaptive = ( + float(self.hypers["num_neighbors_adaptive"]) + if self.hypers["num_neighbors_adaptive"] is not None + else None + ) + self.d_pet = self.hypers["d_pet"] + self.d_node = self.hypers["d_node"] + self.d_head = self.hypers["d_head"] + self.d_feedforward = self.hypers["d_feedforward"] + self.num_heads = self.hypers["num_heads"] + self.num_gnn_layers = self.hypers["num_gnn_layers"] + self.num_attention_layers = self.hypers["num_attention_layers"] + self.normalization = self.hypers["normalization"] + self.activation = self.hypers["activation"] + self.attention_temperature = self.hypers["attention_temperature"] + self.transformer_type = self.hypers["transformer_type"] + self.featurizer_type = self.hypers["featurizer_type"] + + self.atomic_types = dataset_info.atomic_types + self.requested_nl = NeighborListOptions( + cutoff=self.cutoff, + full_list=True, + strict=True, + ) + num_atomic_species = len(self.atomic_types) + self.gnn_layers = torch.nn.ModuleList( + [ + CartesianTransformer( + self.cutoff, + self.cutoff_width, + self.d_pet, + self.num_heads, + self.d_node, + self.d_feedforward, + self.num_attention_layers, + self.normalization, + self.activation, + self.attention_temperature, + self.transformer_type, + num_atomic_species, + layer_index == 0, # is first layer + ) + for layer_index in range(self.num_gnn_layers) + ] + ) + if self.featurizer_type not in AVAILABLE_FEATURIZERS: + raise ValueError( + f"Unknown featurizer type: {self.featurizer_type}. 
" + f"Available options are: {AVAILABLE_FEATURIZERS}" + ) + if self.featurizer_type == "feedforward": + self.num_readout_layers = 1 + self.combination_norms = torch.nn.ModuleList( + [torch.nn.LayerNorm(2 * self.d_pet) for _ in range(self.num_gnn_layers)] + ) + self.combination_mlps = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(2 * self.d_pet, 2 * self.d_pet), + torch.nn.SiLU(), + torch.nn.Linear(2 * self.d_pet, self.d_pet), + ) + for _ in range(self.num_gnn_layers) + ] + ) + else: + self.num_readout_layers = self.num_gnn_layers + self.combination_norms = torch.nn.ModuleList() + self.combination_mlps = torch.nn.ModuleList() + + self.node_embedders = torch.nn.ModuleList( + [ + torch.nn.Embedding(num_atomic_species, self.d_node) + for _ in range(self.num_readout_layers) + ] + ) + self.edge_embedder = torch.nn.Embedding(num_atomic_species, self.d_pet) + + self.node_heads = torch.nn.ModuleDict() + self.edge_heads = torch.nn.ModuleDict() + self.node_last_layers = torch.nn.ModuleDict() + self.edge_last_layers = torch.nn.ModuleDict() + self.last_layer_feature_size = ( + self.num_readout_layers * self.d_head * self.NUM_FEATURE_TYPES + ) # for LLPR + + # the model is always capable of outputting the internal features + self.outputs = { + "features": ModelOutput(per_atom=True, description="internal features") + } + + self.output_shapes: Dict[str, Dict[str, List[int]]] = {} + self.key_labels: Dict[str, Labels] = {} + self.property_labels: Dict[str, List[Labels]] = {} + self.component_labels: Dict[str, List[List[Labels]]] = {} + self.target_names: List[str] = [] + self.last_layer_parameter_names: Dict[str, List[str]] = {} # for LLPR + for target_name, target_info in dataset_info.targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target_info) + + self.register_buffer( + "species_to_species_index", + torch.full((max(self.atomic_types) + 1,), -1), + ) + for i, species in enumerate(self.atomic_types): + self.species_to_species_index[species] = i + + # long-range module + if self.hypers["long_range"]["enable"]: + self.long_range = True + if not self.hypers["long_range"]["use_ewald"]: + warnings.warn( + "Training PET with the LongRangeFeaturizer initialized " + "with `use_ewald=False` causes instabilities during training. " + "The `use_ewald` variable will be force-switched to `True`. 
" + "during training.", + UserWarning, + stacklevel=2, + ) + self.long_range_featurizer = LongRangeFeaturizer( + hypers=self.hypers["long_range"], + feature_dim=self.d_node, + neighbor_list_options=self.requested_nl, + ) + else: + self.long_range = False + self.long_range_featurizer = DummyLongRangeFeaturizer() # for torchscript + + # additive models: these are handled by the trainer at training + # time, and they are added to the output at evaluation time + composition_model = CompositionModel( + hypers={}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + additive_models = [composition_model] + + # Adds the ZBL repulsion model if requested + if self.hypers["zbl"]: + additive_models.append( + ZBL( + {}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if ZBL.is_valid_target(target_name, target_info) + }, + ), + ) + ) + self.additive_models = torch.nn.ModuleList(additive_models) + + # scaler: this is also handled by the trainer at training time + self.scaler = Scaler(hypers={}, dataset_info=dataset_info) + + self.single_label = Labels.single() + + self.finetune_config: Dict[str, Any] = {} + + def supported_outputs(self) -> Dict[str, ModelOutput]: + return self.outputs + + def restart(self, dataset_info: DatasetInfo) -> "PET": + # merge old and new dataset info + merged_info = self.dataset_info.union(dataset_info) + new_atomic_types = [ + at for at in merged_info.atomic_types if at not in self.atomic_types + ] + new_targets = { + key: value + for key, value in merged_info.targets.items() + if key not in self.dataset_info.targets + } + self.has_new_targets = len(new_targets) > 0 + + if len(new_atomic_types) > 0: + raise ValueError( + f"New atomic types found in the dataset: {new_atomic_types}. " + "The PET model does not support adding new atomic types." + ) + + # register new outputs as new last layers + for target_name, target in new_targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target) + + self.dataset_info = merged_info + + # restart the composition and scaler models + self.additive_models[0] = self.additive_models[0].restart( + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + self.scaler = self.scaler.restart(dataset_info) + + return self + + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return [self.requested_nl] + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + """ + Forward pass of the PET model. + + The forward pass processes atomic systems through multiple stages to produce + predictions for the requested outputs. The computation follows a graph neural + network architecture with attention-based message passing. 
+ + **Stage 0: Input Preparation** + + The input systems are first converted into a batched representation containing: + + - `element_indices_nodes` [n_atoms]: Atomic species of the central atoms + - `element_indices_neighbors` [n_atoms, max_num_neighbors]: Atomic species of + neighboring atoms + - `edge_vectors` [n_atoms, max_num_neighbors, 3]: Cartesian edge vectors + between central atoms and their neighbors + - `padding_mask` [n_atoms, max_num_neighbors]: Mask indicating real vs padded + neighbors + - `reverse_neighbor_index` [n_atoms * max_num_neighbors]: Index of the ji edge + for each ij edge, once the edges are flattened into an array whose first + dimension is n_atoms * max_num_neighbors + - `system_indices` [n_atoms]: System index for each central atom + - `sample_labels` [n_atoms, 2]: Metatensor Labels containing indices of each + atom in each system + + **Stage 1: Feature Computation via GNN Layers** + + Node and edge representations are computed by iterating through the GNN layers + following one of two featurization strategies: + + - **Feedforward featurization**: Features are propagated through all + `num_gnn_layers` GNN layers sequentially, using only the final layer outputs + for readout. At each layer, forward and reversed edge messages are combined + using combination MLPs to enable bidirectional information flow. + + - **Residual featurization**: Intermediate node and edge features from each + GNN layer are saved and used during readout. Edge messages between layers are + averaged to maintain information from all hops. + + During this stage, the model: + + - Embeds atomic species into learned node and edge representations + - Applies Cartesian transformer layers to update features via attention + - Uses reversed neighbor lists to enable bidirectional message passing, where + the new input message from atom `j` to atom `i` in GNN layer N+1 is the + reversed message from atom `i` to atom `j` in GNN layer N + - Applies cutoff functions to weight interactions by distance + + If the long-range module is enabled, electrostatic features computed via Ewald + summation (during training) or Particle-Particle Particle Mesh Ewald (P3M) + (during evaluation) are added to the node features from each GNN layer. + + **Stage 2: Intermediate Feature Output (Optional)** + + If "features" is requested in the outputs, node and edge features from all + layers are concatenated to produce intermediate representations. Edge features + are summed over neighbors with cutoff weighting to obtain per-node + contributions. This output can be used for transfer learning or analysis. + + **Stage 3: Last Layer Feature Computation** + + For each requested output, output-specific heads (shallow MLPs with two linear + layers and SiLU activations) are applied to both node and edge features from + each GNN layer. This produces last layer features that are specialized for each + prediction target. These features can be optionally returned as auxiliary + outputs (e.g., "mtt::aux::energy_last_layer_features") for analysis or + transfer learning. 
+ + **Stage 4: Atomic Predictions** + + Final linear layers are applied to the last layer features to produce per-atom + predictions for each requested output: + + - Node and edge last layer features are processed through separate linear + layers for each output block + - Contributions from all GNN layers are summed + - Edge contributions are summed over neighbors with cutoff weighting + - For rank-2 Cartesian tensors (e.g., stress), predictions are symmetrized and + normalized by cell volume + - Multiple tensor blocks per output are handled independently + + **Post-processing (Evaluation Only)** + + During evaluation (not training), the following transformations are applied: + + 1. **Scaling**: Predictions are scaled using learned or configured scale + factors + 2. **Additive contributions**: Composition model and optional ZBL repulsion + contributions are added to the predictions + + :param systems: List of `metatomic.torch.System` objects to process. Each + system should contain atomic positions, species, and cell information, with + neighbor lists computed according to `requested_neighbor_lists()`. + :param outputs: Dictionary of requested outputs in the format + {output_name: ModelOutput(...)}. The model supports: + + - Target properties (energy, forces, stress, etc.) + - "features": intermediate representations from Stage 2 + - Auxiliary last layer features (e.g., + "mtt::aux::energy_last_layer_features") + + :param selected_atoms: Optional `metatensor.torch.Labels` object specifying a + subset of atoms for which to compute outputs. If `None`, all atoms are + included. This is useful for computing properties for specific atomic + environments. + :return: Dictionary of `metatensor.torch.TensorMap` objects containing the + requested outputs. Each TensorMap contains per-atom or per-structure + predictions (depending on the ModelOutput configuration) with appropriate + metatensor metadata (samples, components, properties). 
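As a usage sketch (not part of the patch), a caller could request a target together with its auxiliary last-layer features and the intermediate features as below. The target name, unit, and the `model`/`systems` objects are placeholders; `ModelOutput` is the same metadata class used throughout this file (assumed importable from metatomic.torch, as elsewhere in metatrain).

from metatomic.torch import ModelOutput

requested_outputs = {
    # per-structure energies (per_atom=False sums atomic contributions)
    "energy": ModelOutput(quantity="energy", unit="eV", per_atom=False),
    # per-atom auxiliary last-layer features for the same target
    "mtt::aux::energy_last_layer_features": ModelOutput(per_atom=True),
    # per-atom intermediate representations from Stage 2
    "features": ModelOutput(per_atom=True),
}
# predictions = model(systems, requested_outputs)  # -> Dict[str, TensorMap]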
+ """ + device = systems[0].device + return_dict: Dict[str, TensorMap] = {} + nl_options = self.requested_neighbor_lists()[0] + + if self.single_label.values.device != device: + self._move_labels_to_device(device) + + # **Stage 0: Input Preparation** + ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + padding_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + sample_labels, + centers, + nef_to_edges_neighbor, + ) = systems_to_batch( + systems, + nl_options, + self.atomic_types, + self.species_to_species_index, + self.cutoff_function, + self.cutoff_width, + self.num_neighbors_adaptive, + selected_atoms, + ) + + pair_sample_labels = get_pair_sample_labels( + systems, sample_labels, nl_options, device + ) + + # Optional diagnostic token capture: register temporary module hooks + diagnostic_handles = torch.jit.annotate(List[Any], []) + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + diagnostic_handles = self._prepare_diagnostic_handles( + outputs, + return_dict, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + + # the scaled_dot_product_attention function from torch cannot do + # double backward, so we will use manual attention if needed + use_manual_attention = edge_vectors.requires_grad and self.training + + with torch.profiler.record_function("PET::_calculate_features"): + # **Stage 1: Feature Computation via GNN Layers** + featurizer_inputs: Dict[str, torch.Tensor] = dict( + element_indices_nodes=element_indices_nodes, + element_indices_neighbors=element_indices_neighbors, + edge_vectors=edge_vectors, + edge_distances=edge_distances, + reverse_neighbor_index=reverse_neighbor_index, + padding_mask=padding_mask, + cutoff_factors=cutoff_factors, + ) + for featurizer_input_name, tensor in featurizer_inputs.items(): + if "mtt::features::" + featurizer_input_name in outputs: + return_dict["mtt::features::" + featurizer_input_name] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + node_features_list, edge_features_list = self._calculate_features( + featurizer_inputs, + use_manual_attention=use_manual_attention, + ) + + # If the long-range module is activated, we add the long-range features + # on top of the node features + + if self.long_range: + long_range_features = self._calculate_long_range_features( + systems, node_features_list, edge_distances, padding_mask + ) + for i in range(self.num_readout_layers): + node_features_list[i] = ( + node_features_list[i] + long_range_features + ) * 0.5**0.5 + + # **Stage 2: Intermediate Feature Output (Optional)** + with torch.profiler.record_function("PET::_get_output_features"): + if "features" in outputs: + features_dict = self._get_output_features( + node_features_list, + edge_features_list, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + # Since return_dict.update(features_dict) is not Torch-Scriptable, + # we use a simple iteration over the features_dict items. 
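A minimal sketch (not part of the patch) of why the manual loop is used below: as the comment above notes, TorchScript does not accept `Dict.update`, so dictionaries are merged item by item. The function name here is hypothetical.

import torch
from typing import Dict

@torch.jit.script
def merge_into(
    dst: Dict[str, torch.Tensor], src: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    # item-by-item copy instead of dst.update(src)
    for k, v in src.items():
        dst[k] = v
    return dst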
+ for k, v in features_dict.items(): + return_dict[k] = v + + # **Stage 3: Last Layer Feature Computation** + with torch.profiler.record_function("PET::_calculate_last_layer_features"): + node_last_layer_features_dict, edge_last_layer_features_dict = ( + self._calculate_last_layer_features( + node_features_list, + edge_features_list, + ) + ) + last_layer_features_dict = self._get_output_last_layer_features( + node_last_layer_features_dict, + edge_last_layer_features_dict, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + + for k, v in last_layer_features_dict.items(): + return_dict[k] = v + + # **Stage 4: Atomic Predictions** + with torch.profiler.record_function("PET::_calculate_atomic_predictions"): + node_atomic_predictions_dict, edge_atomic_predictions_dict = ( + self._calculate_atomic_predictions( + node_last_layer_features_dict, + edge_last_layer_features_dict, + padding_mask, + cutoff_factors, + outputs, + ) + ) + atomic_predictions_dict = self._get_output_atomic_predictions( + systems, + node_atomic_predictions_dict, + edge_atomic_predictions_dict, + edge_vectors, + system_indices, + sample_labels, + outputs, + selected_atoms, + ) + + for k, v in atomic_predictions_dict.items(): + return_dict[k] = v + + # **Post-processing (Evaluation Only)** + + with torch.profiler.record_function("PET::post-processing"): + if not self.training: + # at evaluation, we also introduce the scaler and additive contributions + return_dict = self.scaler( + systems, return_dict, selected_atoms=selected_atoms + ) + for additive_model in self.additive_models: + outputs_for_additive_model: Dict[str, ModelOutput] = {} + for name, output in outputs.items(): + if name in additive_model.outputs: + outputs_for_additive_model[name] = output + additive_contributions = additive_model( + systems, + outputs_for_additive_model, + selected_atoms, + ) + for name in additive_contributions: + # TODO: uncomment this after metatensor.torch.add + # is updated to handle sparse sums + # return_dict[name] = metatensor.torch.add( + # return_dict[name], + # additive_contributions[name].to( + # device=return_dict[name].device, + # dtype=return_dict[name].dtype + # ), + # ) + # TODO: "manual" sparse sum: update to metatensor.torch.add + # after sparse sum is implemented in metatensor.operations + output_blocks: List[TensorBlock] = [] + for k, b in return_dict[name].items(): + if k in additive_contributions[name].keys: + output_blocks.append( + _add_block_block( + b, + additive_contributions[name] + .block(k) + .to(device=b.device, dtype=b.dtype), + ) + ) + else: + output_blocks.append(b) + return_dict[name] = TensorMap( + return_dict[name].keys, output_blocks + ) + + # remove any diagnostic hooks we registered and attach tokens + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + for h in diagnostic_handles: + h.remove() + + return return_dict + + def _create_diagnostic_feature_tensormap( + self, + tensor: torch.Tensor, + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> TensorMap: + assert tensor.shape[0] == sample_labels.values.shape[0], ( + "diagnostic feature tensor must be per-atom or per-pair like in shape." 
+ f" Got tensor.shape = {tensor.shape}, " + ) + + device = tensor.device + + outp = tensor.detach().clone() + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + if outp.shape[1] == 1: # node-like, shape (n_atoms, 1, d) + outp = outp.squeeze(1) + labels = sample_labels + + else: # edge-like, shape (n_atoms, num_neighbors, d) + outp = outp[centers, nef_to_edges_neighbor] + labels = pair_sample_labels + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + return TensorMap( + Labels(["_"], torch.tensor([[0]])).to(device=device), + [ + TensorBlock( + values=outp, + samples=labels.to(device=device), + components=[], + properties=Labels( + ["_"], + torch.arange(outp.shape[1]).reshape(-1, 1), + ).to(device=device), + ) + ], + ) + + @torch.jit.ignore + def _prepare_diagnostic_handles( + self, + outputs: Dict[str, ModelOutput], + return_dict: Dict[str, Any], + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> List[Any]: + """ + Prepare forward hooks to capture diagnostic tokens from internal modules. + + :param outputs: Dictionary of requested outputs. + :param return_dict: Dictionary to store captured tokens. + :param centers: Tensor mapping center atoms to their indices. + :param nef_to_edges_neighbor: Tensor mapping neighbor edges to their indices. + :param sample_labels: Labels for individual atoms. + :param pair_sample_labels: Labels for atom pairs. + + :return: List of removable handles for the registered hooks. + """ + diagnostic_handles = [] + + def _resolve_module(path: str) -> Any: + obj: Any = self + for part in path.split("."): + if part.isdigit(): + obj = obj[int(part)] + else: + if not hasattr(obj, part): + raise AttributeError( + f"Module path '{path}' not found at '{part}'" + ) + obj = getattr(obj, part) + return obj + + # Build list of possible module paths that can be captured for these model + # hypers. 
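A minimal sketch (not part of the patch) of the hook mechanism used by the diagnostic handles: a temporary forward hook stashes a module's output and is removed afterwards. The module and tensor shapes are placeholders.

import torch

captured = {}
layer = torch.nn.Linear(8, 8)

def capture_hook(module, inputs, output):
    # store a detached copy of the intermediate output
    captured["layer_output"] = output.detach()

handle = layer.register_forward_hook(capture_hook)
layer(torch.randn(3, 8))
handle.remove()  # hooks are temporary, exactly as for the diagnostic handles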
+ possible_capture_paths: List[str] = [] + + for i in range(self.num_readout_layers): + possible_capture_paths.append(f"node_embedders.{i}") + possible_capture_paths.append("edge_embedder") + + for i in range(self.num_gnn_layers): + # Total GNN layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}_node") + possible_capture_paths.append(f"gnn_layers.{i}_edge") + + # Finer-grained: embeddings and compressions + possible_capture_paths.append(f"gnn_layers.{i}.edge_embedder") + if i > 0: + possible_capture_paths.append(f"gnn_layers.{i}.neighbor_embedder") + possible_capture_paths.append(f"gnn_layers.{i}.compress") + + # transformer layers + for j in range(self.num_attention_layers): + # Transformer layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_node") + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_edge") + + # Finer-grained: attention and MLP submodules + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_contraction" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_mlp" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_mlp" + ) + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}.mlp") + + # TODO: this is subtle, depending on whether the featurizer is feedforward or + # residual. To fix later. + # possible_capture_paths.append(f"combination_norms") + # possible_capture_paths.append(f"combination_mlps") + + for path in possible_capture_paths: + if "mtt::features::" + path not in outputs: + continue + + if "_node" in path: + suffix = "_node" + path = path.replace(suffix, "") + elif "_edge" in path: + suffix = "_edge" + path = path.replace(suffix, "") + else: + suffix = "" + + module = _resolve_module(path) + + def make_hook(p: str, suffix: str) -> Any: + def _hook( + module: torch.nn.Module, inp: Any, outp: Any + ) -> None: + if isinstance(outp, tuple): + assert "_node" in suffix or "_edge" in suffix, ( + "When capturing from a module that returns multiple " + "outputs, the requested output must carry the suffix " + "'_node' or '_edge'." + ) + if suffix == "_node": + tensor = outp[0] + else: + assert suffix == "_edge" + tensor = outp[1] + + return_dict[f"mtt::features::{p}{suffix}"] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + else: + return_dict["mtt::features::" + p] = ( + self._create_diagnostic_feature_tensormap( + outp, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + + return _hook + + handle = module.register_forward_hook(make_hook(path, suffix)) + diagnostic_handles.append(handle) + + return diagnostic_handles + + def _calculate_features( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Calculate node and edge features using the selected featurization strategy. + Returns lists of feature tensors from GNN layers. 
+ + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors + - List of edge feature tensors + In the case of feedforward featurization, each list contains a single tensor + from the final GNN layer. In the case of residual featurization, each list + contains tensors from all GNN layers. + """ + if self.featurizer_type == "feedforward": + return self._feedforward_featurization_impl(inputs, use_manual_attention) + else: + return self._residual_featurization_impl(inputs, use_manual_attention) + + def _feedforward_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Feedforward featurization: iterates features through all GNN layers, + returning only the final layer outputs. Uses combination MLPs to mix + forward and reversed edge messages at each layer. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from the final GNN layer + - List of edge feature tensors from the final GNN layer + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + + input_node_embeddings = self.node_embedders[0](inputs["element_indices_nodes"]) + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for combination_norm, combination_mlp, gnn_layer in zip( + self.combination_norms, self.combination_mlps, self.gnn_layers, strict=True + ): + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + input_node_embeddings = output_node_embeddings + new_input_edge_embeddings = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + # input_messages = 0.5 * (output_edge_embeddings + new_input_messages) + concatenated = torch.cat( + [output_edge_embeddings, new_input_edge_embeddings], dim=-1 + ) + input_edge_embeddings = ( + input_edge_embeddings + + output_edge_embeddings + + combination_mlp(combination_norm(concatenated)) + ) + + node_features_list.append(input_node_embeddings) + edge_features_list.append(input_edge_embeddings) + return node_features_list, edge_features_list + + def _residual_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Residual featurization: saves intermediate features from each GNN layer + for use in readout. 
Averages forward and reversed edge messages between layers. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from all GNN layers + - List of edge feature tensors from all GNN layers + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for node_embedder, gnn_layer in zip( + self.node_embedders, self.gnn_layers, strict=True + ): + input_node_embeddings = node_embedder(inputs["element_indices_nodes"]) + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + node_features_list.append(output_node_embeddings) + edge_features_list.append(output_edge_embeddings) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + # (Flatten, index, and reshape to the original shape) + new_input_messages = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + input_edge_embeddings = 0.5 * (input_edge_embeddings + new_input_messages) + return node_features_list, edge_features_list + + def _calculate_long_range_features( + self, + systems: List[System], + node_features_list: List[torch.Tensor], + edge_distances: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + """ + Calculate long-range electrostatic features using Ewald summation. + Forces use_ewald=True during training for stability. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_distances: Tensor of edge distances [n_atoms, max_num_neighbors]. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :return: Tensor of long-range features [n_atoms, d_pet]. + """ + if self.training: + # Currently, the long-range implementation show instabilities + # during training if P3MCalculator is used instead of the + # EwaldCalculator. We will use the EwaldCalculator for training. 
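A minimal sketch (not part of the patch) of the flatten/index/reshape reindexing used in both featurization methods above: messages stored as (n_atoms, max_num_neighbors, d) are flattened, permuted with `reverse_neighbor_index` so that each i -> j slot receives the j -> i message, and reshaped back. The toy index below simply swaps the two atoms' edge blocks; real indices come from `systems_to_batch`.

import torch

n_atoms, max_num_neighbors, d = 2, 2, 3
messages = torch.arange(n_atoms * max_num_neighbors * d, dtype=torch.float32)
messages = messages.reshape(n_atoms, max_num_neighbors, d)

# hand-made reverse index, for illustration only
reverse_neighbor_index = torch.tensor([2, 3, 0, 1])

flat = messages.reshape(n_atoms * max_num_neighbors, d)
reversed_messages = flat[reverse_neighbor_index].reshape(n_atoms, max_num_neighbors, d)
# reversed_messages[i, k] now holds the message travelling in the opposite direction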
+ self.long_range_featurizer.use_ewald = True + flattened_lengths = edge_distances[padding_mask] + short_range_features = ( + torch.stack(node_features_list).sum(dim=0) + * (1 / len(node_features_list)) ** 0.5 + ) + long_range_features = self.long_range_featurizer( + systems, short_range_features, flattened_lengths + ) + return long_range_features + + def _get_output_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Concatenate node and edge features from all layers into intermediate + feature representations. Edge features are summed with cutoff weighting. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping "features" to a TensorMap of intermediate + representations, either per-atom or summed over atoms. + """ + features_dict: Dict[str, TensorMap] = {} + node_features = torch.cat(node_features_list, dim=1) + edge_features = torch.cat(edge_features_list, dim=2) + edge_features = (edge_features * cutoff_factors[:, :, None]).sum(dim=1) + features = torch.cat([node_features, edge_features], dim=1) + + feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=features, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + features.shape[-1], device=features.device + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + feature_tmap = mts.slice( + feature_tmap, + axis="samples", + selection=selected_atoms, + ) + if requested_outputs["features"].per_atom: + features_dict["features"] = feature_tmap + else: + features_dict["features"] = sum_over_atoms(feature_tmap) + return features_dict + + def _calculate_last_layer_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + ) -> Tuple[Dict[str, List[torch.Tensor]], Dict[str, List[torch.Tensor]]]: + """ + Apply output-specific heads to node and edge features from each GNN layer. + Returns dictionaries mapping output names to lists of head-transformed features. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of node last layer features + - Dictionary mapping output names to lists of edge last layer features + """ + node_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + + # Calculating node last layer features + for output_name, node_heads in self.node_heads.items(): + if output_name not in node_last_layer_features_dict: + node_last_layer_features_dict[output_name] = [] + for i, node_head in enumerate(node_heads): + node_last_layer_features_dict[output_name].append( + node_head(node_features_list[i]) + ) + + # Calculating edge last layer features + for output_name, edge_heads in self.edge_heads.items(): + if output_name not in edge_last_layer_features_dict: + edge_last_layer_features_dict[output_name] = [] + for i, edge_head in enumerate(edge_heads): + edge_last_layer_features_dict[output_name].append( + edge_head(edge_features_list[i]) + ) + + return node_last_layer_features_dict, edge_last_layer_features_dict + + def _get_output_last_layer_features( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge last layer features for requested last layer + features output. Edge features are summed with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping requested last layer features output names + to TensorMaps of last layer features, either per-atom or summed over atoms. 
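A small sketch (not part of the patch) of the naming convention resolved below and by `get_last_layer_features_name` at the end of this file: each target maps to an auxiliary "mtt::aux::<base>_last_layer_features" output.

def last_layer_features_name(target_name: str) -> str:
    base_name = target_name.replace("mtt::", "")
    return f"mtt::aux::{base_name}_last_layer_features"

assert last_layer_features_name("energy") == "mtt::aux::energy_last_layer_features"
assert last_layer_features_name("mtt::dipole") == "mtt::aux::dipole_last_layer_features"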
+ """ + last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + last_layer_features_outputs: Dict[str, TensorMap] = {} + for output_name in node_last_layer_features_dict.keys(): + if not should_compute_last_layer_features(output_name, requested_outputs): + continue + if output_name not in last_layer_features_dict: + last_layer_features_dict[output_name] = [] + for i in range(len(node_last_layer_features_dict[output_name])): + node_last_layer_features = node_last_layer_features_dict[output_name][i] + edge_last_layer_features = edge_last_layer_features_dict[output_name][i] + edge_last_layer_features = ( + edge_last_layer_features * cutoff_factors[:, :, None] + ).sum(dim=1) + last_layer_features_dict[output_name].append(node_last_layer_features) + last_layer_features_dict[output_name].append(edge_last_layer_features) + + for output_name in requested_outputs: + if not ( + output_name.startswith("mtt::aux::") + and output_name.endswith("_last_layer_features") + ): + continue + base_name = output_name.replace("mtt::aux::", "").replace( + "_last_layer_features", "" + ) + # the corresponding output could be base_name or mtt::base_name + if f"mtt::{base_name}" in last_layer_features_dict: + base_name = f"mtt::{base_name}" + last_layer_features_values = torch.cat( + last_layer_features_dict[base_name], dim=1 + ) + last_layer_feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=last_layer_features_values, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + last_layer_features_values.shape[-1], + device=last_layer_features_values.device, + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + last_layer_feature_tmap = mts.slice( + last_layer_feature_tmap, + axis="samples", + selection=selected_atoms, + ) + last_layer_features_options = requested_outputs[output_name] + if last_layer_features_options.per_atom: + last_layer_features_outputs[output_name] = last_layer_feature_tmap + else: + last_layer_features_outputs[output_name] = sum_over_atoms( + last_layer_feature_tmap + ) + return last_layer_features_outputs + + def _calculate_atomic_predictions( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + padding_mask: torch.Tensor, + cutoff_factors: torch.Tensor, + outputs: Dict[str, ModelOutput], + ) -> Tuple[ + Dict[str, List[List[torch.Tensor]]], Dict[str, List[List[torch.Tensor]]] + ]: + """ + Apply final linear layers to last layer features to produce + per-atom predictions. Handles multiple blocks per output and sums + edge contributions with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param outputs: Dictionary of requested outputs. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of lists of node atomic + prediction tensors (one list per GNN layer, one tensor per block) + - Dictionary mapping output names to lists of lists of edge atomic + prediction tensors (one list per GNN layer, one tensor per block) + """ + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + + # Computing node atomic predictions. Since we have last layer features + # for each GNN layer, and each last layer can have multiple blocks, + # we apply each last layer block to each of the last layer features. + + for output_name, node_last_layers in self.node_last_layers.items(): + if output_name in outputs: + node_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, node_last_layer in enumerate(node_last_layers): + node_last_layer_features = node_last_layer_features_dict[ + output_name + ][i] + node_atomic_predictions_by_block: List[torch.Tensor] = [] + for node_last_layer_by_block in node_last_layer.values(): + node_atomic_predictions_by_block.append( + node_last_layer_by_block(node_last_layer_features) + ) + node_atomic_predictions_dict[output_name].append( + node_atomic_predictions_by_block + ) + + # Computing edge atomic predictions. Following the same logic as above, + # we (1) iterate over the last layer features and last layer blocks, and (2) + # sum the edge features with cutoff factors to get their per-node contribution. + + for output_name, edge_last_layers in self.edge_last_layers.items(): + if output_name in outputs: + edge_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, edge_last_layer in enumerate(edge_last_layers): + edge_last_layer_features = edge_last_layer_features_dict[ + output_name + ][i] + edge_atomic_predictions_by_block: List[torch.Tensor] = [] + for edge_last_layer_by_block in edge_last_layer.values(): + edge_atomic_predictions = edge_last_layer_by_block( + edge_last_layer_features + ) + expanded_padding_mask = padding_mask[..., None].repeat( + 1, 1, edge_atomic_predictions.shape[2] + ) + edge_atomic_predictions = torch.where( + ~expanded_padding_mask, 0.0, edge_atomic_predictions + ) + edge_atomic_predictions_by_block.append( + (edge_atomic_predictions * cutoff_factors[:, :, None]).sum( + dim=1 + ) + ) + edge_atomic_predictions_dict[output_name].append( + edge_atomic_predictions_by_block + ) + + return node_atomic_predictions_dict, edge_atomic_predictions_dict + + def _get_output_atomic_predictions( + self, + systems: List[System], + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_vectors: torch.Tensor, + system_indices: torch.Tensor, + sample_labels: Labels, + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge atomic predictions into final TensorMaps. + Handles rank-2 Cartesian tensors by symmetrizing them. + Returns per-atom or per-structure predictions based on output configuration. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of node atomic prediction tensors (one list per GNN layer, + one tensor per block). 
+ :param edge_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of edge atomic prediction tensors (one list per GNN layer, + one tensor per block). + :param edge_vectors: Tensor of edge vectors [n_atoms, max_num_neighbors, 3]. + :param system_indices: Tensor mapping each atom to its system index + [n_atoms]. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param outputs: Dictionary of requested outputs. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :return: Dictionary mapping requested output names to TensorMaps of + predictions, either per-atom or summed over atoms. + """ + atomic_predictions_tmap_dict: Dict[str, TensorMap] = {} + for output_name in self.target_names: + if output_name in outputs: + atomic_predictions_by_block = { + key: torch.zeros( + 1, dtype=edge_vectors.dtype, device=edge_vectors.device + ) + for key in self.output_shapes[output_name].keys() + } + + node_atomic_predictions_by_block = node_atomic_predictions_dict[ + output_name + ] + edge_atomic_predictions_by_block = edge_atomic_predictions_dict[ + output_name + ] + for i in range(len(node_atomic_predictions_by_block)): + node_atomic_prediction_block = node_atomic_predictions_by_block[i] + edge_atomic_prediction_block = edge_atomic_predictions_by_block[i] + for j, key in enumerate(atomic_predictions_by_block): + node_atomic_predictions = node_atomic_prediction_block[j] + edge_atomic_predictions = edge_atomic_prediction_block[j] + atomic_predictions_by_block[key] = atomic_predictions_by_block[ + key + ] + (node_atomic_predictions + edge_atomic_predictions) + + if output_name == "non_conservative_stress": # TODO: variants + block_key = list(atomic_predictions_by_block.keys())[0] + output_shapes_values = list( + self.output_shapes[output_name].values() + ) + num_properties = output_shapes_values[0][-1] + symmetrized = process_non_conservative_stress( + atomic_predictions_by_block[block_key], + systems, + system_indices, + num_properties, + ) + atomic_predictions_by_block[block_key] = symmetrized + + blocks = [ + TensorBlock( + values=atomic_predictions_by_block[key].reshape([-1] + shape), + samples=sample_labels, + components=components, + properties=properties, + ) + for key, shape, components, properties in zip( + self.output_shapes[output_name].keys(), + self.output_shapes[output_name].values(), + self.component_labels[output_name], + self.property_labels[output_name], + strict=True, + ) + ] + atomic_predictions_tmap_dict[output_name] = TensorMap( + keys=self.key_labels[output_name], + blocks=blocks, + ) + # If selected atoms request is provided, we slice the atomic predictions + # tensor maps to get the predictions for the selected atoms only. + + if selected_atoms is not None: + for output_name, tmap in atomic_predictions_tmap_dict.items(): + atomic_predictions_tmap_dict[output_name] = mts.slice( + tmap, axis="samples", selection=selected_atoms + ) + + # If per-atom predictions are requested, we return the atomic predictions + # tensor maps. Otherwise, we sum the atomic predictions over the atoms + # to get the final per-structure predictions for each requested output. 
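A minimal sketch (not part of the patch) of the transformation that `process_non_conservative_stress` (defined later in this file) applies to each raw rank-2 prediction: divide by the cell volume and symmetrize. The 3x3 values and the volume below are made up.

import torch

raw_prediction = torch.randn(3, 3)  # unsymmetrized per-atom prediction
volume = 125.0                      # abs(det(cell)) of the atom's system
stress = 0.5 * (raw_prediction + raw_prediction.T) / volume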
+ + for output_name, atomic_property in atomic_predictions_tmap_dict.items(): + if outputs[output_name].per_atom: + atomic_predictions_tmap_dict[output_name] = atomic_property + else: + atomic_predictions_tmap_dict[output_name] = sum_over_atoms( + atomic_property + ) + + return atomic_predictions_tmap_dict + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + context: Literal["restart", "finetune", "export"], + ) -> "PET": + if context == "restart": + logging.info(f"Using latest model from epoch {checkpoint['epoch']}") + model_state_dict = checkpoint["model_state_dict"] + elif context in {"finetune", "export"}: + logging.info(f"Using best model from epoch {checkpoint['best_epoch']}") + model_state_dict = checkpoint["best_model_state_dict"] + else: + raise ValueError("Unknown context tag for checkpoint loading!") + + # Create the model + model_data = checkpoint["model_data"] + model = cls( + hypers=model_data["model_hypers"], + dataset_info=model_data["dataset_info"], + ) + + finetune_config = model_state_dict.pop("finetune_config", {}) + if finetune_config: + # Apply the finetuning strategy + model = apply_finetuning_strategy(model, finetune_config) + state_dict_iter = iter(model_state_dict.values()) + next(state_dict_iter) # skip the species_to_species_index + dtype = next(state_dict_iter).dtype + model.to(dtype).load_state_dict(model_state_dict) + model.additive_models[0].sync_tensor_maps() + model.scaler.sync_tensor_maps() + + # Loading the metadata from the checkpoint + model.metadata = merge_metadata(model.metadata, checkpoint.get("metadata")) + + return model + + def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: + dtype = next(self.parameters()).dtype + if dtype not in self.__supported_dtypes__: + raise ValueError(f"unsupported dtype {dtype} for PET") + + # Make sure the model is all in the same dtype + # For example, after training, the additive models could still be in + # float64 + self.to(dtype) + + # Additionally, the composition model contains some `TensorMap`s that cannot + # be registered correctly with Pytorch. This function moves them: + self.additive_models[0].weights_to(torch.device("cpu"), torch.float64) + + interaction_ranges = [self.num_gnn_layers * self.cutoff] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + + capabilities = ModelCapabilities( + outputs=self.outputs, + atomic_types=self.atomic_types, + interaction_range=interaction_range, + length_unit=self.dataset_info.length_unit, + supported_devices=self.__supported_devices__, + dtype=dtype_to_str(dtype), + ) + + metadata = merge_metadata(self.metadata, metadata) + + return AtomisticModel(self.eval(), metadata, capabilities) + + def _add_output(self, target_name: str, target_info: TargetInfo) -> None: + """ + Register a new output target by creating corresponding heads and last layers. + Sets up node/edge heads and linear layers for all readout layers. + + :param target_name: Name of the target to add. + :param target_info: TargetInfo object containing details about the target. + """ + # one output shape for each tensor block, grouped by target (i.e. 
tensormap) + self.output_shapes[target_name] = {} + for key, block in target_info.layout.items(): + dict_key = target_name + for n, k in zip(key.names, key.values, strict=True): + dict_key += f"_{n}_{int(k)}" + self.output_shapes[target_name][dict_key] = [ + len(comp.values) for comp in block.components + ] + [len(block.properties.values)] + + self.outputs[target_name] = ModelOutput( + quantity=target_info.quantity, + unit=target_info.unit, + per_atom=True, + description=target_info.description, + ) + + self.node_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_node, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_pet, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.node_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: torch.nn.Linear( + self.d_head, + prod(shape), + bias=("o3_lambda_0" in key and "o3_sigma_1" in key), + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: torch.nn.Linear( + self.d_head, + prod(shape), + bias=("o3_lambda_0" in key and "o3_sigma_1" in key), + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + # Register last-layer parameters, in the same order as they are returned as + # last-layer features in the model + self.last_layer_parameter_names[target_name] = [] + for layer_index in range(self.num_readout_layers): + for key in self.output_shapes[target_name].keys(): + self.last_layer_parameter_names[target_name].append( + f"node_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + self.last_layer_parameter_names[target_name].append( + f"edge_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + + ll_features_name = get_last_layer_features_name(target_name) + self.outputs[ll_features_name] = ModelOutput( + per_atom=True, description=f"last layer features for {target_name}" + ) + self.key_labels[target_name] = target_info.layout.keys + self.component_labels[target_name] = [ + block.components for block in target_info.layout.blocks() + ] + self.property_labels[target_name] = [ + block.properties for block in target_info.layout.blocks() + ] + + def _move_labels_to_device(self, device: torch.device) -> None: + self.single_label = self.single_label.to(device) + self.key_labels = { + output_name: label.to(device) + for output_name, label in self.key_labels.items() + } + self.component_labels = { + output_name: [ + [labels.to(device) for labels in components_block] + for components_block in components_tmap + ] + for output_name, components_tmap in self.component_labels.items() + } + self.property_labels = { + output_name: [labels.to(device) for labels in properties_tmap] + for output_name, properties_tmap in self.property_labels.items() + } + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["model_ckpt_version"] == v: + update = getattr(checkpoints, f"model_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["model_ckpt_version"] = v + 
1 + + if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using model " + f"version {checkpoint['model_ckpt_version']}, while the current model " + f"version is {cls.__checkpoint_version__}." + ) + + return checkpoint + + def get_checkpoint(self) -> Dict: + model_state_dict = self.state_dict() + model_state_dict["finetune_config"] = self.finetune_config + checkpoint = { + "architecture_name": "pet_1", + "model_ckpt_version": self.__checkpoint_version__, + "metadata": self.metadata, + "model_data": { + "model_hypers": self.hypers, + "dataset_info": self.dataset_info, + }, + "epoch": None, + "best_epoch": None, + "model_state_dict": model_state_dict, + "best_model_state_dict": self.state_dict(), + } + return checkpoint + + +def process_non_conservative_stress( + tensor: torch.Tensor, + systems: List[System], + system_indices: torch.Tensor, + num_properties: int, +) -> torch.Tensor: + """ + Symmetrizes and normalizes by the volume rank-2 Cartesian tensors that are meant + to predict the non-conservative stress. + + :param tensor: Tensor of shape [n_atoms, 9 * num_properties]. + :param systems: List of `metatomic.torch.System` objects to process. + :param system_indices: Tensor mapping each atom to its system index [n_atoms]. + :param num_properties: Number of properties in the tensor (e.g., 6 for stress). + :return: Symmetrized tensor of shape [n_atoms, 3, 3, num_properties], divided by the + cell volume. + """ + # Reshape to 3x3 matrix per atom + tensor_as_three_by_three = tensor.reshape(-1, 3, 3, num_properties) + + # Normalize by cell volume + volumes = torch.stack([torch.abs(torch.det(system.cell)) for system in systems]) + # Zero volume can happen due to metatomic's convention of zero cell + # vectors for non-periodic directions. The actual volume is +inf + volumes[volumes == 0.0] = torch.inf + volumes_by_atom = volumes[system_indices].unsqueeze(1).unsqueeze(2).unsqueeze(3) + tensor_as_three_by_three = tensor_as_three_by_three / volumes_by_atom + + # Symmetrize + tensor_as_three_by_three = ( + tensor_as_three_by_three + tensor_as_three_by_three.transpose(1, 2) + ) / 2.0 + + return tensor_as_three_by_three + + +def get_last_layer_features_name(target_name: str) -> str: + """ + Get the auxiliary output name for last layer features of a target. + + :param target_name: Name of the target. + :return: Name of the corresponding last layer features output. + """ + base_name = target_name.replace("mtt::", "") + return f"mtt::aux::{base_name}_last_layer_features" + + +def should_compute_last_layer_features( + output_name: str, requested_outputs: Dict[str, ModelOutput] +) -> bool: + """ + Check if last layer features should be computed for an output. + + :param output_name: Name of the output to check. + :param requested_outputs: Dictionary of requested outputs. + :return: True if last layer features should be computed, False otherwise. 
+ """ + if output_name in requested_outputs: + return True + ll_features_name = get_last_layer_features_name( + output_name.replace("mtt::aux::", "") + ) + return ll_features_name in requested_outputs diff --git a/src/metatrain/pet_1/modules/adaptive_cutoff.py b/src/metatrain/pet_1/modules/adaptive_cutoff.py new file mode 100644 index 0000000000..5f6206bdb1 --- /dev/null +++ b/src/metatrain/pet_1/modules/adaptive_cutoff.py @@ -0,0 +1,166 @@ +from typing import Optional + +import torch + +from .utilities import cutoff_func_bump as cutoff_func + + +# minimum value for the probe cutoff. this avoids getting too close +# to the central atom. in practice it could be also set to a larger value +DEFAULT_MIN_PROBE_CUTOFF = 0.5 +# recommended smooth cutoff width for effective neighbor number calculation +# smaller values lead to a more "step-like" behavior, but can be +# numerically unstable. in practice this will be called with the +# same cutoff as the main cutoff function +DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH = 1.0 + + +def get_adaptive_cutoffs( + centers: torch.Tensor, + edge_distances: torch.Tensor, + num_neighbors_adaptive: float, + num_nodes: int, + max_cutoff: float, + min_cutoff: float = DEFAULT_MIN_PROBE_CUTOFF, + cutoff_width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, + probe_spacing: Optional[float] = None, + weight_width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the adaptive cutoff values for each center atom. + + :param centers: Indices of the center atoms. + :param edge_distances: Distances between centers and their neighbors. + :param num_neighbors_adaptive: Target number of neighbors per atom. + :param num_nodes: Total number of center atoms. + :param max_cutoff: Maximum cutoff distance to consider. + :param min_cutoff: Minimum cutoff distance to consider. + :param cutoff_width: Width of the smooth cutoff taper region. + :param probe_spacing: Spacing between probe cutoffs. If None, it will be + automatically determined from the cutoff width. + :param weight_width: Width of the cutoff selection weight function. If None, it + will be automatically determined from the empirical neighbor counts. + :return: Adapted cutoff distances for each center atom. + """ + + # heuristic for the grid spacing of probe cutoffs, based on a + # the smoothness of the cutoff function + if probe_spacing is None: + probe_spacing = cutoff_width / 4.0 + probe_cutoffs = torch.arange( + min_cutoff, + max_cutoff, + probe_spacing, + device=edge_distances.device, + dtype=edge_distances.dtype, + ) + with torch.profiler.record_function("PET::get_effective_num_neighbors"): + effective_num_neighbors = get_effective_num_neighbors( + edge_distances, + probe_cutoffs, + centers, + num_nodes, + width=cutoff_width, + ) + + with torch.profiler.record_function("PET::get_cutoff_weights"): + cutoffs_weights = get_gaussian_cutoff_weights( + effective_num_neighbors, + num_neighbors_adaptive, + width=weight_width, + ) + with torch.profiler.record_function("PET::calculate_adapted_cutoffs"): + adapted_atomic_cutoffs = probe_cutoffs @ cutoffs_weights.T + return adapted_atomic_cutoffs + + +def get_effective_num_neighbors( + edge_distances: torch.Tensor, + probe_cutoffs: torch.Tensor, + centers: torch.Tensor, + num_nodes: int, + width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, +) -> torch.Tensor: + """ + Computes the effective number of neighbors for each probe cutoff. + + :param edge_distances: Distances between centers and their neighbors. + :param probe_cutoffs: Probe cutoff distances. 
+ :param centers: Indices of the center atoms. + :param num_nodes: Total number of center atoms. + :param width: Width of the cutoff function. + :return: Effective number of neighbors for each center atom and probe cutoff. + """ + + weights = cutoff_func( + edge_distances.unsqueeze(0), probe_cutoffs.unsqueeze(1), width + ) + + probe_num_neighbors = torch.zeros( + (len(probe_cutoffs), num_nodes), + dtype=edge_distances.dtype, + device=edge_distances.device, + ) + # accumulate the weights for all probe cutoffs and center atoms at once + probe_num_neighbors.index_add_(1, centers, weights) + probe_num_neighbors = probe_num_neighbors.T + + return probe_num_neighbors + + +def get_gaussian_cutoff_weights( + effective_num_neighbors: torch.Tensor, + num_neighbors_adaptive: float, + width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the weights for each probe cutoff based on + the effective number of neighbors using Gaussian weights + centered at the expected number of neighbors. + + :param effective_num_neighbors: Effective number of neighbors for each center atom + and probe cutoff. + :param num_neighbors_adaptive: Target maximum number of neighbors per atom. + :param width: Width of the Gaussian cutoff selection function. + :return: Weights for each probe cutoff. + """ + diff = effective_num_neighbors - num_neighbors_adaptive + + # adds a "baseline" corresponding to uniformly-distributed atoms + # this has multiple "good" effects: it pushes the cutoff "out" when + # there are few neighbors, and "in" when there are many, and it + # stabilizes the weights with respect to variations in the neighbor + # distribution when there are empty ranges leading to "flat" + # neighbor count distribution + x = torch.linspace( + 0, + 1, + effective_num_neighbors.shape[1], + device=effective_num_neighbors.device, + dtype=effective_num_neighbors.dtype, + ) + baseline = num_neighbors_adaptive * x**3 + + diff = diff + baseline.unsqueeze(0) + if width is None: + # adaptive width from neighbor-count slope along probe axis (last dim) + eps = 1e-12 + if diff.shape[-1] == 1: + # Can't compute gradient from single point; use scaled diff as proxy + width_t = diff.abs() * 0.5 + eps + else: + # Compute numerical gradient: centered differences for interior, + # one-sided differences at boundaries + (width_t,) = torch.gradient(diff, dim=-1) + width_t = width_t.abs().clamp_min(eps) + else: + width_t = torch.ones_like(diff) * width + + logw = -0.5 * (diff / width_t) ** 2 + weights = torch.exp(logw - logw.max()) + + # row-wise normalization of the weights + weights_sum = weights.sum(dim=1, keepdim=True) + weights = weights / weights_sum + + return weights diff --git a/src/metatrain/pet_1/modules/finetuning.py b/src/metatrain/pet_1/modules/finetuning.py new file mode 100644 index 0000000000..27f2086f19 --- /dev/null +++ b/src/metatrain/pet_1/modules/finetuning.py @@ -0,0 +1,264 @@ +# mypy: disable-error-code=misc +# We ignore misc errors in this file because TypedDict +# with default values is not allowed by mypy. 
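Before the schema definitions below, a sketch (not part of the patch) of what a LoRA finetuning configuration matching the `LoRaFinetuneHypers` TypedDict could look like; the checkpoint path is a placeholder and the numeric values mirror the defaults used in `apply_finetuning_strategy`.

lora_finetune_hypers = {
    "method": "lora",
    "read_from": "path/to/pretrained_checkpoint.ckpt",  # placeholder path
    "config": {
        "rank": 4,
        "alpha": 8.0,
        "target_modules": ["input_linear", "output_linear"],
    },
    "inherit_heads": {},
}
# model = apply_finetuning_strategy(model, lora_finetune_hypers)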
+from typing import Any, Optional, Tuple + +import torch +import torch.nn as nn +from typing_extensions import Literal, NotRequired, TypedDict + + +class LoRaFinetuneConfig(TypedDict): + """Configuration for LoRA finetuning strategy.""" + + rank: int + """Rank of the LoRA matrices.""" + alpha: float + """Scaling factor for the LoRA matrices.""" + target_modules: NotRequired[list[str]] + + +class HeadsFinetuneConfig(TypedDict): + """Configuration for heads finetuning strategy.""" + + head_modules: list[str] + """List of module name prefixes for the prediction heads to finetune.""" + last_layer_modules: list[str] + """List of module name prefixes for the last layers to finetune.""" + + +class NoFinetuneHypers(TypedDict): + """Hypers that indicate that no finetuning is to be applied.""" + + read_from: None = None + """No finetuning is indicated by setting this argument to None. + + The rest of finetuning hyperparameters are then ignored. + """ + method: NotRequired[Any] + config: NotRequired[Any] + inherit_heads: NotRequired[Any] + + +class FullFinetuneHypers(TypedDict): + """Hyperparameters to use full finetuning of PET models. + + This means all model parameters are trainable. + """ + + method: Literal["full"] = "full" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: NotRequired[Any] + """No configuration needed for full finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class LoRaFinetuneHypers(TypedDict): + """Hyperparameters for LoRA finetuning of PET models. + + Injects LoRA layers and finetunes only them. + """ + + method: Literal["lora"] = "lora" + """Finetuning method to use""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: LoRaFinetuneConfig + """Configuration for LoRA finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class HeadsFinetuneHypers(TypedDict): + """Hyperparameters for heads finetuning of PET models. + + Freezes all model parameters except for the prediction heads + and last layers. + """ + + method: Literal["heads"] = "heads" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: HeadsFinetuneConfig + """Configuration for heads finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +FinetuneHypers = FullFinetuneHypers | LoRaFinetuneHypers | HeadsFinetuneHypers + + +def apply_finetuning_strategy(model: nn.Module, strategy: FinetuneHypers) -> nn.Module: + """ + Apply the specified finetuning strategy to the model. + This function modifies the model in place based on the provided strategy. + + :param model: The model to be finetuned. + :param strategy: A dictionary specifying the finetuning strategy. + The strategy method can be one of the following: + - lora: Inject LoRA layers into the model, or reapply training if already + present. 
+ - heads: Freeze all parameters except for the heads and last layers. + - full: All parameters are trainable. + Additionally, the strategy can include an "inherit_heads" key, + which is a dictionary mapping the new trainable targets to the existing + targets in the model. This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization. + :return: The modified model with the finetuning strategy applied. + """ + + for param in model.parameters(): + param.requires_grad = True + + if strategy["method"] == "full": + # Full finetuning, all parameters are trainable + pass + + elif strategy["method"] == "lora": + lora_config = strategy["config"] + lora_already_applied = any(isinstance(m, LoRALinear) for m in model.modules()) + if not lora_already_applied: + model_device = next(model.parameters()).device + model_dtype = next(model.parameters()).dtype + model = inject_lora_layers( + model, + target_modules=tuple( + lora_config.get("target_modules", ["input_linear", "output_linear"]) + ), + rank=lora_config.get("rank", 4), + alpha=lora_config.get("alpha", 8), + device=model_device, + dtype=model_dtype, + ) + + # Freeze all except LoRA + for name, param in model.named_parameters(): + if "lora_" not in name: + param.requires_grad = False + + elif strategy["method"] == "heads": + heads_config = strategy.get( + "config", + { + "head_modules": ["node_heads", "edge_heads"], + "last_layer_modules": ["node_last_layers", "edge_last_layers"], + }, + ) + + head_keywords = heads_config.get("head_modules", []) + last_layer_keywords = heads_config.get("last_layer_modules", []) + + for name, param in model.named_parameters(): + if any(name.startswith(kw) for kw in head_keywords + last_layer_keywords): + param.requires_grad = True + else: + param.requires_grad = False + + else: + raise ValueError( + f"Unknown finetuning strategy: {strategy['method']}. Available methods " + "are: 'full', 'lora', 'heads'." + ) + + model.finetune_config = strategy + + inherit_heads_config = strategy["inherit_heads"] + if inherit_heads_config: + for dest_target_name, source_target_name in inherit_heads_config.items(): + model_parameters = dict(model.named_parameters()) + if not any(f".{source_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the source target name '{source_target_name}' was not found in " + "the model. Please specify the correct source target name." + ) + if not any(f".{dest_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the destination target name '{dest_target_name}' was not found " + "in the model. Please specify the correct destination target name." + ) + for name, param in model_parameters.items(): + if f".{source_target_name}." in name: + corresponding_dest_name = name.replace( + source_target_name, dest_target_name + ) + if corresponding_dest_name in model_parameters: + model_parameters[corresponding_dest_name].data.copy_(param.data) + else: + raise ValueError( + f"Destination head '{dest_target_name}' not found in model." + ) + return model + + +def inject_lora_layers( + model: nn.Module, + target_modules: Tuple[str, ...] = ("input_linear", "output_linear"), + rank: int = 4, + alpha: float = 1.0, + device: Optional[str] = None, + dtype: Optional[torch.dtype] = None, +) -> nn.Module: + """ + Inject LoRA layers into the model. 
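A minimal sketch (not part of the patch) of the low-rank update that the `LoRALinear` wrapper defined later in this file adds on top of a frozen linear layer, y = W x + (alpha / rank) * B(A(x)), written out with plain modules. Sizes and values are placeholders.

import torch

x = torch.randn(5, 16)
base = torch.nn.Linear(16, 32)               # pretrained, frozen layer
lora_A = torch.nn.Linear(16, 4, bias=False)  # rank-4 down-projection
lora_B = torch.nn.Linear(4, 32, bias=False)  # rank-4 up-projection
alpha, rank = 8.0, 4

y = base(x) + (alpha / rank) * lora_B(lora_A(x))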
+ This function replaces the specified linear layers in the model with + LoRALinear layers. + + :param model: The model to modify. + :param target_modules: A tuple of strings specifying the names of the attributes of + the modules to be replaced with LoRA layers. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + :param device: The device to which the LoRA layers should be moved. If None, the + LoRA layers will be on the same device as the original model. + :param dtype: The data type to which the LoRA layers should be converted. If + None, the LoRA layers will have the same dtype as the original model. + :return: The modified model with LoRA layers injected. + """ + for _, module in model.named_modules(): + for attr in target_modules: + if hasattr(module, attr): + linear = getattr(module, attr) + if isinstance(linear, nn.Linear): + lora_linear = LoRALinear(linear, rank=rank, alpha=alpha) + lora_linear = lora_linear.to(dtype=dtype, device=device) + setattr(module, attr, lora_linear) + return model + + +class LoRALinear(nn.Module): + """ + LoRA Linear layer. + This is a wrapper around nn.Linear that adds LoRA functionality. + LoRA is a technique for low-rank adaptation of large language models. + It allows for efficient fine-tuning of large models by injecting low-rank + matrices into the model's weights. + + :param linear_layer: The original linear layer to be wrapped. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + """ + + def __init__(self, linear_layer: nn.Module, rank: int = 4, alpha: float = 1.0): + super().__init__() + self.linear = linear_layer + self.lora_A = nn.Linear(linear_layer.in_features, rank, bias=False) + self.lora_B = nn.Linear(rank, linear_layer.out_features, bias=False) + self.scaling = alpha / rank + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + self.scaling * self.lora_B(self.lora_A(x)) diff --git a/src/metatrain/pet_1/modules/nef.py b/src/metatrain/pet_1/modules/nef.py new file mode 100644 index 0000000000..20ed6dc8fc --- /dev/null +++ b/src/metatrain/pet_1/modules/nef.py @@ -0,0 +1,248 @@ +""" +Module with functions to manipulate NEF (Node Edge Feature) arrays. + +The NEF representation is what the internals of PET use. +In the NEF representation, the first dimension is the center node +(i.e. the "i" node in an "i -> j" edge), and the second dimension +is the edges for that node. Not all center nodes have the same number +of edges, so padding is used to ensure that all nodes have the same +number of edges. + +Most of the functions have the purpose of converting between +edge arrays with shape (n_edges, ...) and NEF arrays with shape +(n_nodes, n_edges_per_node, ...). +""" + +from typing import List, Optional, Tuple + +import torch + + +def get_nef_indices( + centers: torch.Tensor, n_nodes: int, n_edges_per_node: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes tensors of indices useful to convert between edge + and NEF layouts; the usage and function of `nef_indices` and + `nef_to_edges_neighbor` is clear in the ``edge_array_to_nef`` + and ``nef_array_to_edges`` functions below. + + :param centers: A 1D tensor of shape (n_edges,) containing the + indices of the center nodes for each edge, with the center nodes + being the "i" node in an "i -> j" edge. + :param n_nodes: The number of nodes in the graph. + :param n_edges_per_node: The maximum number of edges per node. 
+ + :return: A tuple with three tensors (nef_indices, nef_to_edges_neighbor, nef_mask). + In particular: + nef_array = edge_array[nef_indices] + edge_array = nef_array[centers, nef_to_edges_neighbor] + The third output, nef_mask, is a mask that can be used to + filter out the padding values in the NEF array, as different + nodes will have, in general, different number of edges. + """ + + bincount = torch.bincount(centers, minlength=n_nodes) + + arange = torch.arange(n_edges_per_node, device=centers.device) + arange_expanded = arange.view(1, -1).expand(n_nodes, -1) + nef_mask = arange_expanded < bincount.view(-1, 1) + + argsort = torch.argsort(centers, stable=True) + + nef_indices = torch.zeros( + (n_nodes, n_edges_per_node), dtype=torch.long, device=centers.device + ) + nef_indices[nef_mask] = argsort + + nef_to_edges_neighbor = torch.empty_like(centers, dtype=torch.long) + nef_to_edges_neighbor[argsort] = arange_expanded[nef_mask] + + return nef_indices, nef_to_edges_neighbor, nef_mask + + +def get_corresponding_edges(array: torch.Tensor) -> torch.Tensor: + """ + Computes the corresponding edge (i.e., the edge that goes in the + opposite direction) for each edge in the array; this is useful + in the message-passing operation. + + :param array: A 2D tensor of shape (n_edges, 5). For each i -> j + edge, the first column contains the index of the center node i, + the second column contains the index of the neighbor node j, + and the last three columns contain the cell shifts along x, y, and z + directions, respectively. + + :return: A 1D tensor of shape (n_edges,) containing, for each edge, + the index of the corresponding edge (i.e., the edge that goes + in the opposite direction). If the input array is empty, an + empty tensor is returned. + """ + + if array.numel() == 0: + return torch.empty((0,), dtype=array.dtype, device=array.device) + + array = array.to(torch.int64) # avoid overflow + + centers = array[:, 0] + neighbors = array[:, 1] + cell_shifts_x = array[:, 2] + cell_shifts_y = array[:, 3] + cell_shifts_z = array[:, 4] + + # will be useful later + negative_cell_shifts_x = -cell_shifts_x + negative_cell_shifts_y = -cell_shifts_y + negative_cell_shifts_z = -cell_shifts_z + + # create a unique identifier for each edge + # first, we shift the cell_shifts so that the minimum value is 0 + min_cell_shift_x = cell_shifts_x.min() + cell_shifts_x = cell_shifts_x - min_cell_shift_x + negative_cell_shifts_x = negative_cell_shifts_x - min_cell_shift_x + + min_cell_shift_y = cell_shifts_y.min() + cell_shifts_y = cell_shifts_y - min_cell_shift_y + negative_cell_shifts_y = negative_cell_shifts_y - min_cell_shift_y + + min_cell_shift_z = cell_shifts_z.min() + cell_shifts_z = cell_shifts_z - min_cell_shift_z + negative_cell_shifts_z = negative_cell_shifts_z - min_cell_shift_z + + max_centers_neigbors = centers.max() + 1 # same as neighbors.max() + 1 + max_shift_x = cell_shifts_x.max() + 1 + max_shift_y = cell_shifts_y.max() + 1 + max_shift_z = cell_shifts_z.max() + 1 + + size_1 = max_shift_z + size_2 = max_shift_y * size_1 + size_3 = max_shift_x * size_2 + size_4 = max_centers_neigbors * size_3 + + unique_id = ( + centers * size_4 + + neighbors * size_3 + + cell_shifts_x * size_2 + + cell_shifts_y * size_1 + + cell_shifts_z + ) + + # the inverse is the same, but centers and neighbors are swapped + # and we use the negative values of the cell_shifts + unique_id_inverse = ( + neighbors * size_4 + + centers * size_3 + + negative_cell_shifts_x * size_2 + + negative_cell_shifts_y * size_1 + + 
negative_cell_shifts_z + ) + + unique_id_argsort = unique_id.argsort() + unique_id_inverse_argsort = unique_id_inverse.argsort() + + corresponding_edges = torch.empty_like(centers) + corresponding_edges[unique_id_argsort] = unique_id_inverse_argsort + + return corresponding_edges.to(array.dtype) + + +def edge_array_to_nef( + edge_array: torch.Tensor, + nef_indices: torch.Tensor, + mask: Optional[torch.Tensor] = None, + fill_value: float = 0.0, +) -> torch.Tensor: + """Converts an edge array to a NEF array. + + :param edge_array: A tensor where the first dimension is the index of + the edge, i.e. with shape (n_edges, ...). + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param mask: An optional boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. If provided, + the output NEF array will have the values in the positions + where the mask is False set to ``fill_value``. + :param fill_value: The value to use to fill the positions in the + NEF array where the mask is False. Only used if ``mask`` is + provided. + + :return: A tensor with the same information as ``edge_array``, + but in NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + If ``mask`` is provided, the values in the positions where + the mask is False are set to ``fill_value``. + """ + if mask is None: + return edge_array[nef_indices] + else: + return torch.where( + mask.reshape(mask.shape + (1,) * (len(edge_array.shape) - 1)), + edge_array[nef_indices], + fill_value, + ) + + +def nef_array_to_edges( + nef_array: torch.Tensor, centers: torch.Tensor, nef_to_edges_neighbor: torch.Tensor +) -> torch.Tensor: + """Converts a NEF array to an edge array. + + :param nef_array: A tensor where the first two dimensions are the + indices of the NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + :param centers: The indices of the center nodes for each edge. + :param nef_to_edges_neighbor: The indices of the edges for each + neighbor in the NEF layout, as returned by the ``get_nef_indices`` function. + + :return: A tensor with the same information as ``nef_array``, + but in edge layout, i.e. with shape (n_edges, ...). + """ + return nef_array[centers, nef_to_edges_neighbor] + + +def compute_reversed_neighbor_list( + nef_indices: torch.Tensor, + corresponding_edges: torch.Tensor, + nef_mask: torch.Tensor, +) -> torch.Tensor: + """ + Creates a reversed neighborlist, where for each + center atom `i` and its neighbor `j` in the original + neighborlist, the position of atom `i` in the list + of neighbors of atom `j` is returned. + + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param corresponding_edges: The indices of the corresponding edges, + as returned by the ``get_corresponding_edges`` function. + :param nef_mask: A boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. + :return: A tensor of the same shape as ``nef_indices``, + where each entry contains the position of the center + atom in the neighborlist of the corresponding neighbor atom. 
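+
+        Illustration of the invariant implemented below (for non-padded entries;
+        ``j`` denotes the neighbor of the edge stored at NEF slot ``(i, k)``)::
+
+            nef_indices[j, reversed_neighbor_list[i, k]]
+                == corresponding_edges[nef_indices[i, k]]
+
+        Padded entries of the output are set to 0.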
+ """ + num_atoms, max_num_neighbors = nef_indices.shape + + flat_edge_indices = nef_indices.reshape(-1) + flat_positions = torch.arange(max_num_neighbors, device=nef_indices.device).repeat( + num_atoms + ) + flat_mask = nef_mask.reshape(-1) + + if flat_edge_indices.numel() == 0: + max_edge_index = 0 + else: + max_edge_index = int(flat_edge_indices.max().item()) + 1 + size: List[int] = [max_edge_index] + + edge_index_to_position = torch.full( + size, + 0, + dtype=torch.long, + device=nef_indices.device, + ) + edge_index_to_position[flat_edge_indices[flat_mask]] = flat_positions[flat_mask] + + reverse_edge_idx = corresponding_edges[nef_indices] + reversed_neighbor_list = edge_index_to_position[reverse_edge_idx] + reversed_neighbor_list = reversed_neighbor_list.masked_fill(~nef_mask, 0) + + return reversed_neighbor_list diff --git a/src/metatrain/pet_1/modules/structures.py b/src/metatrain/pet_1/modules/structures.py new file mode 100644 index 0000000000..f7e67f9eaf --- /dev/null +++ b/src/metatrain/pet_1/modules/structures.py @@ -0,0 +1,377 @@ +from typing import List, Optional, Tuple + +import torch +from metatensor.torch import Labels +from metatomic.torch import NeighborListOptions, System + +from .adaptive_cutoff import get_adaptive_cutoffs +from .nef import ( + compute_reversed_neighbor_list, + edge_array_to_nef, + get_corresponding_edges, + get_nef_indices, +) +from .utilities import cutoff_func_bump, cutoff_func_cosine + + +def concatenate_structures( + systems: List[System], + neighbor_list_options: NeighborListOptions, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, +]: + """ + Concatenate a list of systems into a single batch. + + :param systems: List of systems to concatenate. + :param neighbor_list_options: Options for the neighbor list. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the concatenated positions, centers, neighbors, + species, cells, cell shifts, system indices, and sample labels. 
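+
+        A usage sketch (illustrative; ``systems`` and ``nl_options`` are assumed
+        to be prepared by the caller)::
+
+            (positions, centers, neighbors, species, cells,
+             cell_shifts, system_indices, sample_labels) = concatenate_structures(
+                systems, nl_options
+            )
+            # positions: (n_atoms_total, 3); centers, neighbors: (n_edges_total,)
+            # cells: (n_systems, 3, 3); cell_shifts: (n_edges_total, 3)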
+ """ + + positions: List[torch.Tensor] = [] + centers: List[torch.Tensor] = [] + neighbors: List[torch.Tensor] = [] + species: List[torch.Tensor] = [] + cell_shifts: List[torch.Tensor] = [] + cells: List[torch.Tensor] = [] + system_indices: List[torch.Tensor] = [] + atom_indices: List[torch.Tensor] = [] + node_counter = 0 + + for i, system in enumerate(systems): + assert len(system.known_neighbor_lists()) >= 1, "no neighbor list found" + neighbor_list = system.get_neighbor_list(neighbor_list_options) + nl_values = neighbor_list.samples.values + + centers_values = nl_values[:, 0] + neighbors_values = nl_values[:, 1] + cell_shifts_values = nl_values[:, 2:] + + if selected_atoms is not None: + system_selected_atoms = selected_atoms.values[:, 1][ + selected_atoms.values[:, 0] == i + ] + unique_centers = torch.unique(centers_values) + system_selected_atoms = torch.unique( + torch.cat([system_selected_atoms, unique_centers]) + ) + # calculate the mapping from the ghost atoms to the real atoms + if torch.numel(unique_centers) == 0: + max_center_index = -1 + else: + max_center_index = int(unique_centers.max()) + ghost_to_real_index = torch.full( + [ + max_center_index + 1, + ], + -1, + device=centers_values.device, + dtype=centers_values.dtype, + ) + for j, unique_center_index in enumerate(unique_centers): + ghost_to_real_index[unique_center_index] = j + + centers_values = ghost_to_real_index[centers_values] + neighbors_values = ghost_to_real_index[neighbors_values] + else: + system_selected_atoms = torch.arange( + len(system), device=system.positions.device + ) + + positions.append(system.positions[system_selected_atoms]) + species.append(system.types[system_selected_atoms]) + + centers.append(centers_values + node_counter) + neighbors.append(neighbors_values + node_counter) + cell_shifts.append(cell_shifts_values) + + cells.append(system.cell) + + node_counter += len(system_selected_atoms) + system_indices.append( + torch.full((len(system_selected_atoms),), i, device=system.positions.device) + ) + atom_indices.append( + torch.arange(len(system_selected_atoms), device=system.positions.device) + ) + + positions = torch.cat(positions) + centers = torch.cat(centers) + neighbors = torch.cat(neighbors) + species = torch.cat(species) + cells = torch.stack(cells) + cell_shifts = torch.cat(cell_shifts) + system_indices = torch.cat(system_indices) + atom_indices = torch.cat(atom_indices) + + sample_values = torch.stack( + [system_indices, atom_indices], + dim=1, + ) + sample_labels = Labels( + names=["system", "atom"], + values=sample_values, + assume_unique=True, + ) + + return ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) + + +def systems_to_batch( + systems: List[System], + options: NeighborListOptions, + all_species_list: List[int], + species_to_species_index: torch.Tensor, + cutoff_function: str, + cutoff_width: float, + num_neighbors_adaptive: Optional[float] = None, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, + torch.Tensor, + torch.Tensor, +]: + """ + Converts a list of systems to a batch required for the PET model. + + :param systems: List of systems to convert to a batch. + :param options: Options for the neighbor list. + :param all_species_list: List of all atomic species in the dataset. + :param species_to_species_index: Mapping from atomic species to species indices. 
+ :param cutoff_function: Type of the smoothing function at the cutoff. + :param cutoff_width: Width of the cutoff function for a cutoff mask. + :param num_neighbors_adaptive: Optional maximum number of neighbors per atom. + If provided, the adaptive cutoff scheme will be used for each atom to + approximately select this number of neighbors. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the batch tensors. + The batch consists of the following tensors: + - `element_indices_nodes`: The atomic species of the central atoms + - `element_indices_neighbors`: The atomic species of the neighboring atoms + - `edge_vectors`: The cartesian edge vectors between the central atoms and their + neighbors + - `edge_distances`: The distances between the central atoms and their neighbors + - `padding_mask`: A padding mask indicating which neighbors are real, and which + are padded + - `reverse_neighbor_index`: The reversed neighbor list for each central atom + - `cutoff_factors`: The cutoff function values for each edge + - `system_indices`: The system index for each atom in the batch + - `sample_labels`: Labels indicating the system and atom indices for each atom + + """ + ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) = concatenate_structures(systems, options, selected_atoms) + + # somehow the backward of this operation is very slow at evaluation, + # where there is only one cell, therefore we simplify the calculation + # for that case + if len(cells) == 1: + cell_contributions = cell_shifts.to(cells.dtype) @ cells[0] + else: + cell_contributions = torch.einsum( + "ab, abc -> ac", + cell_shifts.to(cells.dtype), + cells[system_indices[centers]], + ) + edge_vectors = positions[neighbors] - positions[centers] + cell_contributions + edge_distances = torch.norm(edge_vectors, dim=-1) + 1e-15 + + if selected_atoms is not None: + if torch.numel(centers) == 0: + num_nodes = 0 + else: + num_nodes = int(centers.max()) + 1 + else: + num_nodes = len(positions) + + if num_neighbors_adaptive is not None: + with torch.profiler.record_function("PET::get_adaptive_cutoffs"): + # Adaptive cutoff scheme to approximately select `num_neighbors_adaptive` + # neighbors for each atom + atomic_cutoffs = get_adaptive_cutoffs( + centers, + edge_distances, + num_neighbors_adaptive, + num_nodes, + options.cutoff, + cutoff_width=cutoff_width, + ) + # Symmetrize the cutoffs between pairs of atoms (PET needs this symmetry + # due to its corresponding edge indexing ij -> ji) + pair_cutoffs = (atomic_cutoffs[centers] + atomic_cutoffs[neighbors]) / 2.0 + with torch.profiler.record_function("PET::adaptive_cutoff_masking"): + # Apply cutoff mask + cutoff_mask = edge_distances <= pair_cutoffs + + pair_cutoffs = pair_cutoffs[cutoff_mask] + centers = centers[cutoff_mask] + neighbors = neighbors[cutoff_mask] + edge_vectors = edge_vectors[cutoff_mask] + cell_shifts = cell_shifts[cutoff_mask] + edge_distances = edge_distances[cutoff_mask] + else: + pair_cutoffs = options.cutoff * torch.ones( + len(centers), device=positions.device, dtype=positions.dtype + ) + + num_neighbors = torch.bincount(centers, minlength=num_nodes) + max_edges_per_node = int(torch.max(num_neighbors)) + + # uncomment these to print out stats on the adaptive cutoff behavior + # print("adaptive_cutoffs", *pair_cutoffs.tolist()) + # print("num_neighbors", *num_neighbors.tolist()) + + if cutoff_function.lower() == "bump": + # use bump switching 
function for adaptive cutoff
+        cutoff_factors = cutoff_func_bump(edge_distances, pair_cutoffs, cutoff_width)
+    elif cutoff_function.lower() == "cosine":
+        # backward-compatible cosine switching for fixed cutoff
+        cutoff_factors = cutoff_func_cosine(edge_distances, pair_cutoffs, cutoff_width)
+    else:
+        raise ValueError(
+            f"Unknown cutoff function type: {cutoff_function}. "
+            f"Supported types are 'Cosine' and 'Bump'."
+        )
+
+    # Convert to NEF (Node-Edge-Feature) format:
+    nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices(
+        centers, num_nodes, max_edges_per_node
+    )
+
+    # Element indices
+    element_indices_nodes = species_to_species_index[species]
+    element_indices_neighbors = element_indices_nodes[neighbors]
+
+    # Send everything to NEF:
+    edge_vectors = edge_array_to_nef(edge_vectors, nef_indices)
+    edge_distances = torch.sqrt(torch.sum(edge_vectors**2, dim=2) + 1e-15)
+    element_indices_neighbors = edge_array_to_nef(
+        element_indices_neighbors, nef_indices
+    )
+    cutoff_factors = edge_array_to_nef(cutoff_factors, nef_indices, nef_mask, 0.0)
+
+    corresponding_edges = get_corresponding_edges(
+        torch.concatenate(
+            [centers.unsqueeze(-1), neighbors.unsqueeze(-1), cell_shifts],
+            dim=-1,
+        )
+    )
+
+    # These are the two arrays we need for message passing with edge reversals,
+    # if indexing happens in a two-dimensional way:
+    # edges_ji = edges_ij[reversed_neighbor_list, neighbors_index]
+    reversed_neighbor_list = compute_reversed_neighbor_list(
+        nef_indices, corresponding_edges, nef_mask
+    )
+    neighbors_index = edge_array_to_nef(neighbors, nef_indices).to(torch.int64)
+
+    # Here, we compute the array that allows indexing into a flattened
+    # version of the edge array (where the first two dimensions are merged):
+    reverse_neighbor_index = (
+        neighbors_index * neighbors_index.shape[1] + reversed_neighbor_list
+    )
+    # At this point, we have `reverse_neighbor_index[~nef_mask] = 0`, which creates
+    # many identical indices and slows down the backward pass enormously
+    # (see https://github.com/pytorch/pytorch/issues/41162).
+    # We therefore replace the padded indices with a sequence of unique indices.
+    reverse_neighbor_index[~nef_mask] = torch.arange(
+        int(torch.sum(~nef_mask)), device=reverse_neighbor_index.device
+    )
+
+    return (
+        element_indices_nodes,
+        element_indices_neighbors,
+        edge_vectors,
+        edge_distances,
+        nef_mask,
+        reverse_neighbor_index,
+        cutoff_factors,
+        system_indices,
+        sample_labels,
+        centers,
+        nef_to_edges_neighbor,
+    )
+
+
+def get_pair_sample_labels(
+    systems: List[System],
+    sample_labels: Labels,
+    nl_options: NeighborListOptions,
+    device: torch.device,
+) -> Labels:
+    """
+    Builds the pair sample labels for the input ``systems``, based on the pre-computed
+    neighbor list. These are 'off-site', i.e. not including self-interactions.
+
+    :param systems: List of systems to build the pair sample labels for.
+    :param sample_labels: The sample labels for per-atom quantities.
+    :param nl_options: The neighbor list options to use for building the offsite labels.
+    :param device: The device to put the labels on.
+    :return: A ``Labels`` object with the off-site pair samples, containing one
+        entry per (first_atom -> second_atom) pair in the neighbor list.
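+
+        A usage sketch (illustrative; all inputs are assumed to be prepared by
+        the caller)::
+
+            pair_labels = get_pair_sample_labels(systems, sample_labels, options, device)
+            # pair_labels.names == ["system", "first_atom", "second_atom",
+            #                       "cell_shift_a", "cell_shift_b", "cell_shift_c"]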
+ """ + sample_names = [ + "system", + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ] + + pair_sample_values = [] + for system_idx, system in enumerate(systems): + neighbor_list = system.get_neighbor_list(nl_options) + nl_values = neighbor_list.samples.values + + pair_sample_values.append( + torch.hstack( + [ + torch.full( + (nl_values.shape[0], 1), + system_idx, + dtype=torch.int32, + device=device, + ), + nl_values, + ], + ) + ) + pair_sample_values = torch.vstack(pair_sample_values) + pair_sample_labels = Labels(sample_names, pair_sample_values).to(device=device) + + return pair_sample_labels diff --git a/src/metatrain/pet_1/modules/transformer.py b/src/metatrain/pet_1/modules/transformer.py new file mode 100644 index 0000000000..5907314fec --- /dev/null +++ b/src/metatrain/pet_1/modules/transformer.py @@ -0,0 +1,548 @@ +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from .utilities import DummyModule + + +AVAILABLE_NORMALIZATIONS = ["LayerNorm", "RMSNorm"] +AVAILABLE_TRANSFORMER_TYPES = ["PostLN", "PreLN"] +AVAILABLE_ACTIVATIONS = ["SiLU", "SwiGLU"] + + +class FeedForward(nn.Module): + def __init__(self, d_model: int, dim_feedforward: int, activation: str) -> None: + super().__init__() + + # Check if activation is "swiglu" string + if activation.lower() == "swiglu": + # SwiGLU mode: single projection produces both "value" and "gate" + self.w_in = nn.Linear(d_model, 2 * dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = torch.nn.Identity() + self.is_swiglu = True + else: + # Standard mode: regular activation function + self.w_in = nn.Linear(d_model, dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = getattr(F, activation.lower()) + self.is_swiglu = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_swiglu: + # SwiGLU activation: split into value and gate + v, g = self.w_in(x).chunk(2, dim=-1) + x = v * torch.sigmoid(g) + x = self.w_out(x) + else: + # Standard activation + x = self.w_in(x) + x = self.activation(x) + x = self.w_out(x) + return x + + +class AttentionBlock(nn.Module): + """ + Multi-head attention block. + + :param total_dim: The total dimension of the input and output tensors. + :param num_heads: The number of attention heads. + :param temperature: An additional scaling factor for attention scores. + This is combined with the standard scaling by the square root of + the head dimension. + :param epsilon: A small value to avoid division by zero. + """ + + def __init__( + self, + total_dim: int, + num_heads: int, + temperature: float, + epsilon: float = 1e-15, + ) -> None: + super(AttentionBlock, self).__init__() + + self.input_linear = nn.Linear(total_dim, 3 * total_dim) + self.output_linear = nn.Linear(total_dim, total_dim) + + self.num_heads = num_heads + self.epsilon = epsilon + self.temperature = temperature + if total_dim % num_heads != 0: + raise ValueError("total dimension is not divisible by the number of heads") + self.head_dim = total_dim // num_heads + + def forward( + self, x: torch.Tensor, cutoff_factors: torch.Tensor, use_manual_attention: bool + ) -> torch.Tensor: + """ + Forward pass for the attention block. 
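+
+        Schematically, the cutoff factors act as an additive attention mask in
+        log space, i.e. the per-head output is
+        ``softmax(Q @ K.T / (sqrt(head_dim) * temperature) + log(clamp(cutoff_factors, eps))) @ V``,
+        so edges with a (numerically) zero cutoff factor receive zero attention
+        weight.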
+ + :param x: The input tensor, of shape (batch_size, seq_length, total_dim) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: The output tensor, of shape (batch_size, seq_length, total_dim) + """ + initial_shape = x.shape + x = self.input_linear(x) + x = x.reshape( + initial_shape[0], initial_shape[1], 3, self.num_heads, self.head_dim + ) + x = x.permute(2, 0, 3, 1, 4) + + queries, keys, values = x[0], x[1], x[2] + attn_weights = torch.clamp(cutoff_factors[:, None, :, :], self.epsilon) + attn_weights = torch.log(attn_weights) + if use_manual_attention: + x = manual_attention(queries, keys, values, attn_weights, self.temperature) + else: + x = torch.nn.functional.scaled_dot_product_attention( + queries, + keys, + values, + attn_mask=attn_weights, + scale=1.0 / (self.head_dim**0.5 * self.temperature), + ) + x = x.transpose(1, 2).reshape(initial_shape) + x = self.output_linear(x) + return x + + +class TransformerLayer(torch.nn.Module): + """ + Single layer of a Transformer. + + :param d_model: The dimension of the model. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param temperature: An additional scaling factor for attention scores. 
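+
+    A construction sketch (the values below are illustrative only)::
+
+        layer = TransformerLayer(
+            d_model=128,
+            n_heads=4,
+            dim_node_features=256,
+            dim_feedforward=512,
+            norm="LayerNorm",
+            activation="SiLU",
+            transformer_type="PostLN",
+        )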
+ """ + + def __init__( + self, + d_model: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + temperature: float = 1.0, + ) -> None: + super(TransformerLayer, self).__init__() + self.attention = AttentionBlock(d_model, n_heads, temperature) + self.transformer_type = transformer_type + self.d_model = d_model + norm_class = getattr(nn, norm) + self.norm_attention = norm_class(d_model) + self.norm_mlp = norm_class(d_model) + self.mlp = FeedForward(d_model, dim_feedforward, activation) + self.expanded_node_features = False + if dim_node_features != d_model: + self.expanded_node_features = True + self.center_contraction = nn.Linear(dim_node_features, d_model) + self.center_expansion = nn.Linear(d_model, dim_node_features) + self.norm_center_features = norm_class(dim_node_features) + self.center_mlp = FeedForward( + dim_node_features, 2 * dim_node_features, activation + ) + else: + self.center_contraction = torch.nn.Identity() + self.center_expansion = torch.nn.Identity() + self.norm_center_features = torch.nn.Identity() + self.center_mlp = torch.nn.Identity() + + def _forward_pre_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + new_tokens = self.attention( + self.norm_attention(tokens), cutoff_factors, use_manual_attention + ) + output_node_embeddings, output_edge_embeddings = torch.split( + new_tokens, [1, new_tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + + output_edge_embeddings = edge_embeddings + output_edge_embeddings + output_edge_embeddings = output_edge_embeddings + self.mlp( + self.norm_mlp(output_edge_embeddings) + ) + + return output_node_embeddings, output_edge_embeddings + + def _forward_post_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + tokens = self.norm_attention( + tokens + self.attention(tokens, cutoff_factors, use_manual_attention) + ) + tokens = self.norm_mlp(tokens + self.mlp(tokens)) + output_node_embeddings, output_edge_embeddings = torch.split( + tokens, [1, tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + return output_node_embeddings, output_edge_embeddings + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward 
pass for a single Transformer layer. + + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + if self.transformer_type == "PostLN": + node_embeddings, edge_embeddings = self._forward_post_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + if self.transformer_type == "PreLN": + node_embeddings, edge_embeddings = self._forward_pre_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class Transformer(torch.nn.Module): + """ + Transformer implementation. + + :param d_model: The dimension of the model. + :param num_layers: The number of transformer layers. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param attention_temperature: The temperature scaling factor for attention + scores. This is combined with the standard scaling by the square root of + the head dimension. + """ + + def __init__( + self, + d_model: int, + num_layers: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + attention_temperature: float = 1.0, + ) -> None: + super(Transformer, self).__init__() + if norm not in AVAILABLE_NORMALIZATIONS: + raise ValueError( + f"Unknown normalization flag: {norm}. " + f"Please choose from: {AVAILABLE_NORMALIZATIONS}" + ) + + if transformer_type not in AVAILABLE_TRANSFORMER_TYPES: + raise ValueError( + f"Unknown transformer flag: {transformer_type}. " + f"Please choose from: {AVAILABLE_TRANSFORMER_TYPES}" + ) + self.transformer_type = transformer_type + + if activation not in AVAILABLE_ACTIVATIONS: + raise ValueError( + f"Unknown activation flag: {activation}. " + f"Please choose from: {AVAILABLE_ACTIVATIONS}" + ) + + self.layers = nn.ModuleList( + [ + TransformerLayer( + d_model=d_model, + n_heads=n_heads, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + temperature=attention_temperature, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the Transformer. 
+ + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + for layer in self.layers: + node_embeddings, edge_embeddings = layer( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class CartesianTransformer(torch.nn.Module): + """ + Cartesian Transformer implementation for handling 3D coordinates. + + :param cutoff: The cutoff distance for neighbor interactions. + :param cutoff_width: The width of the cutoff function. + :param d_model: The dimension of the model. + :param n_head: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param n_layers: The number of transformer layers. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param attention_temperature: The temperature scaling factor for attention scores. + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param n_atomic_species: The number of atomic species. + :param is_first: Whether this is the first transformer in the model. + """ + + def __init__( + self, + cutoff: float, + cutoff_width: float, + d_model: int, + n_head: int, + dim_node_features: int, + dim_feedforward: int, + n_layers: int, + norm: str, + activation: str, + attention_temperature: float, + transformer_type: str, + n_atomic_species: int, + is_first: bool, + ) -> None: + super(CartesianTransformer, self).__init__() + self.is_first = is_first + self.cutoff = cutoff + self.cutoff_width = cutoff_width + self.trans = Transformer( + d_model=d_model, + num_layers=n_layers, + n_heads=n_head, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + attention_temperature=attention_temperature, + ) + + self.edge_embedder = nn.Linear(4, d_model) + + if not is_first: + n_merge = 3 + else: + n_merge = 2 + + self.compress = nn.Sequential( + nn.Linear(n_merge * d_model, d_model), + torch.nn.SiLU(), + nn.Linear(d_model, d_model), + ) + + self.neighbor_embedder = DummyModule() # for torchscript + if not is_first: + self.neighbor_embedder = nn.Embedding(n_atomic_species, d_model) + + def forward( + self, + input_node_embeddings: torch.Tensor, + input_messages: torch.Tensor, + element_indices_neighbors: torch.Tensor, + edge_vectors: torch.Tensor, + padding_mask: torch.Tensor, + edge_distances: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the CartesianTransformer. 
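+
+        Schematically, each edge token is built as
+        ``compress(cat(edge_embedder([r_ij, |r_ij|]), neighbor_embedder(Z_j), message_ij))``
+        (the neighbor species embedding is only used when this is not the first
+        GNN layer), and the per-atom node embedding is prepended as a central
+        token inside each transformer layer.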
+ + :param input_node_embeddings: The input node embeddings, of shape + (n_nodes, d_model) + :param input_messages: The input messages to the transformer, of shape + (n_nodes, max_num_neighbors, d_model) + :param element_indices_neighbors: The atomic species of the neighboring atoms, + of shape (n_nodes, max_num_neighbors) + :param edge_vectors: The cartesian edge vectors between the central atoms and + their neighbors, of shape (n_nodes, max_num_neighbors, 3) + :param padding_mask: A padding mask indicating which neighbors are real, and + which are padded, of shape (n_nodes, max_num_neighbors) + :param edge_distances: The distances between the central atoms and their + neighbors, of shape (n_nodes, max_num_neighbors) + :param cutoff_factors: The cutoff factors for the edges, of shape + (n_nodes, max_num_neighbors) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (n_nodes, d_model) + - The output edge embeddings, of shape (n_nodes, max_num_neighbors, d_model) + """ + node_embeddings = input_node_embeddings + edge_embeddings = [edge_vectors, edge_distances[:, :, None]] + edge_embeddings = torch.cat(edge_embeddings, dim=2) + edge_embeddings = self.edge_embedder(edge_embeddings) + + if not self.is_first: + neighbor_elements_embeddings = self.neighbor_embedder( + element_indices_neighbors + ) + edge_tokens = torch.cat( + [edge_embeddings, neighbor_elements_embeddings, input_messages], dim=2 + ) + else: + neighbor_elements_embeddings = torch.empty( + 0, device=edge_vectors.device, dtype=edge_vectors.dtype + ) # for torch script + edge_tokens = torch.cat([edge_embeddings, input_messages], dim=2) + + edge_tokens = self.compress(edge_tokens) + # tokens = torch.cat([node_elements_embedding[:, None, :], tokens], dim=1) + + padding_mask_with_central_token = torch.ones( + padding_mask.shape[0], dtype=torch.bool, device=padding_mask.device + ) + total_padding_mask = torch.cat( + [padding_mask_with_central_token[:, None], padding_mask], dim=1 + ) + + cutoff_subfactors = torch.ones( + padding_mask.shape[0], + dtype=cutoff_factors.dtype, + device=padding_mask.device, + ) + cutoff_factors = torch.cat([cutoff_subfactors[:, None], cutoff_factors], dim=1) + cutoff_factors[~total_padding_mask] = 0.0 + + cutoff_factors = cutoff_factors[:, None, :] + cutoff_factors = cutoff_factors.repeat(1, cutoff_factors.shape[2], 1) + + initial_num_tokens = edge_vectors.shape[1] + max_num_tokens = input_messages.shape[1] + + output_node_embeddings, output_edge_embeddings = self.trans( + node_embeddings[:, None, :], + edge_tokens[:, :max_num_tokens, :], + cutoff_factors=cutoff_factors[ + :, : (max_num_tokens + 1), : (max_num_tokens + 1) + ], + use_manual_attention=use_manual_attention, + ) + if max_num_tokens < initial_num_tokens: + padding = torch.zeros( + output_edge_embeddings.shape[0], + initial_num_tokens - max_num_tokens, + output_edge_embeddings.shape[2], + device=output_edge_embeddings.device, + ) + output_edge_embeddings = torch.cat([output_edge_embeddings, padding], dim=1) + output_node_embeddings = output_node_embeddings.squeeze(1) + return output_node_embeddings, output_edge_embeddings + + +def manual_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: torch.Tensor, + temperature: float, +) -> torch.Tensor: + """ + 
Implements the attention operation manually, using basic PyTorch operations. + We need it because the built-in PyTorch attention does not support double backward, + which is needed when training with conservative forces. + + :param q: The queries + :param k: The keys + :param v: The values + :param attn_mask: The attention mask + :param temperature: An additional scaling factor for attention scores. + :return: The result of the attention operation + """ + attention_weights = ( + torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5 * temperature) + ) + attn_mask + attention_weights = attention_weights.softmax(dim=-1) + attention_output = torch.matmul(attention_weights, v) + return attention_output diff --git a/src/metatrain/pet_1/modules/utilities.py b/src/metatrain/pet_1/modules/utilities.py new file mode 100644 index 0000000000..5795c5aed8 --- /dev/null +++ b/src/metatrain/pet_1/modules/utilities.py @@ -0,0 +1,63 @@ +import torch + + +def cutoff_func_bump( + values: torch.Tensor, cutoff: torch.Tensor, width: float, eps: float = 1e-6 +) -> torch.Tensor: + """ + Bump cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :param eps: Avoid computing at values too close to the edges. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + f[mask_active] = 0.5 * ( + 1 + torch.tanh(1 / torch.tan(torch.pi * scaled_values[mask_active])) + ) + f[mask_smaller] = 1.0 + + return f + + +def cutoff_func_cosine( + values: torch.Tensor, cutoff: torch.Tensor, width: float +) -> torch.Tensor: + """ + Cosine cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + + f[mask_active] = 0.5 + 0.5 * torch.cos(torch.pi * scaled_values[mask_active]) + f[mask_smaller] = 1.0 + return f + + +class DummyModule(torch.nn.Module): + """Dummy torch module to make torchscript happy. 
+ This model should never be run""" + + def __init__(self) -> None: + super(DummyModule, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise RuntimeError("This model should never be run") diff --git a/src/metatrain/pet_1/trainer.py b/src/metatrain/pet_1/trainer.py new file mode 100644 index 0000000000..49e82212bc --- /dev/null +++ b/src/metatrain/pet_1/trainer.py @@ -0,0 +1,639 @@ +import copy +import logging +import math +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Union, cast + +import torch +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader, DistributedSampler + +from metatrain.utils.abc import ModelInterface, TrainerInterface +from metatrain.utils.additive import get_remove_additive_transform +from metatrain.utils.augmentation import RotationalAugmenter +from metatrain.utils.data import ( + CollateFn, + CombinedDataLoader, + Dataset, + get_num_workers, + unpack_batch, + validate_num_workers, +) +from metatrain.utils.distributed.batch_utils import should_skip_batch +from metatrain.utils.distributed.distributed_data_parallel import ( + DistributedDataParallel, +) +from metatrain.utils.distributed.slurm import DistributedEnvironment +from metatrain.utils.evaluate_model import evaluate_model +from metatrain.utils.io import check_file_extension +from metatrain.utils.logging import ROOT_LOGGER, MetricLogger +from metatrain.utils.loss import LossAggregator, LossSpecification +from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists_transform, +) +from metatrain.utils.per_atom import average_by_num_atoms +from metatrain.utils.scaler import get_remove_scale_transform +from metatrain.utils.transfer import batch_to + +from . import checkpoints +from .documentation import TrainerHypers +from .model import PET +from .modules.finetuning import apply_finetuning_strategy + + +def get_scheduler( + optimizer: torch.optim.Optimizer, + train_hypers: TrainerHypers, + steps_per_epoch: int, +) -> LambdaLR: + """ + Get a CosineAnnealing learning-rate scheduler with warmup + + :param optimizer: The optimizer for which to create the scheduler. + :param train_hypers: The training hyperparameters. + :param steps_per_epoch: The number of steps per epoch. + :return: The learning rate scheduler. 
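+
+    Schematically, with ``total = num_epochs * steps_per_epoch`` and
+    ``w = int(warmup_fraction * total)``, the factor applied to the base learning
+    rate at step ``s`` is::
+
+        s <  w:  s / w                                        # linear warmup
+        s >= w:  0.5 * (1 + cos(pi * (s - w) / (total - w)))  # cosine decay to 0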
+ """ + total_steps = train_hypers["num_epochs"] * steps_per_epoch + warmup_steps = int(train_hypers["warmup_fraction"] * total_steps) + min_lr_ratio = 0.0 # hardcoded for now, could be made configurable in the future + + def lr_lambda(current_step: int) -> float: + if current_step < warmup_steps: + # Linear warmup + return float(current_step) / float(max(1, warmup_steps)) + else: + # Cosine decay + progress = (current_step - warmup_steps) / float( + max(1, total_steps - warmup_steps) + ) + cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) + return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay + + scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) + return scheduler + + +class Trainer(TrainerInterface[TrainerHypers]): + __checkpoint_version__ = 12 + + def __init__(self, hypers: TrainerHypers) -> None: + super().__init__(hypers) + + self.optimizer_state_dict: Optional[Dict[str, Any]] = None + self.scheduler_state_dict: Optional[Dict[str, Any]] = None + self.epoch: Optional[int] = None + self.best_epoch: Optional[int] = None + self.best_metric: Optional[float] = None + self.best_model_state_dict: Optional[Dict[str, Any]] = None + self.best_optimizer_state_dict: Optional[Dict[str, Any]] = None + + def train( + self, + model: PET, + dtype: torch.dtype, + devices: List[torch.device], + train_datasets: List[Union[Dataset, torch.utils.data.Subset]], + val_datasets: List[Union[Dataset, torch.utils.data.Subset]], + checkpoint_dir: str, + ) -> None: + assert dtype in PET.__supported_dtypes__ + + is_distributed = self.hypers["distributed"] + is_finetune = self.hypers["finetune"]["read_from"] is not None + + if is_distributed: + if len(devices) > 1: + raise ValueError( + "Requested distributed training with the `multi-gpu` device. " + " If you want to run distributed training with PET, please " + "set `device` to cuda." 
+ ) + # the calculation of the device number works both when GPUs on different + # processes are not visible to each other and when they are + distr_env = DistributedEnvironment(self.hypers["distributed_port"]) + device_number = distr_env.local_rank % torch.cuda.device_count() + device = torch.device("cuda", device_number) + torch.distributed.init_process_group(backend="nccl", device_id=device) + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + else: + rank = 0 + device = devices[0] + # only one device, as we don't support non-distributed multi-gpu for now + + if is_distributed: + logging.info(f"Training on {world_size} devices with dtype {dtype}") + else: + logging.info(f"Training on device {device} with dtype {dtype}") + + # Apply fine-tuning strategy if provided + if is_finetune: + assert self.hypers["finetune"]["read_from"] is not None # for mypy + model = apply_finetuning_strategy(model, self.hypers["finetune"]) + method = self.hypers["finetune"]["method"] + num_params = sum(p.numel() for p in model.parameters()) + num_trainable_params = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + + logging.info(f"Applied finetuning strategy: {method}") + logging.info( + f"Number of trainable parameters: {num_trainable_params} " + f"[{num_trainable_params / num_params:.2%} %]" + ) + inherit_heads = self.hypers["finetune"]["inherit_heads"] + if inherit_heads: + logging.info( + "Inheriting initial weights for heads and last layers for targets: " + f"from {list(inherit_heads.values())} to " + f"{list(inherit_heads.keys())}" + ) + + # Move the model to the device and dtype: + model.to(device=device, dtype=dtype) + # The additive models of PET are always in float64 (to avoid numerical errors in + # the composition weights, which can be very large). 
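+        # (Illustration of the precision concern: a per-atom residual of ~1e-3
+        # on top of a composition baseline of ~1e5 is a relative change of
+        # ~1e-8, below float32 resolution (~1e-7), but well within float64.)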
+ for additive_model in model.additive_models: + additive_model.to(dtype=torch.float64) + model.scaler.to(dtype=torch.float64) + + logging.info("Calculating composition weights") + model.additive_models[0].train_model( # this is the composition model + train_datasets, + model.additive_models[1:], + self.hypers["batch_size"], + is_distributed, + self.hypers["atomic_baseline"], + ) + + if self.hypers["scale_targets"]: + logging.info("Calculating scaling weights") + model.scaler.train_model( + train_datasets, + model.additive_models, + self.hypers["batch_size"], + is_distributed, + self.hypers["fixed_scaling_weights"], + ) + + logging.info("Setting up data loaders") + + if is_distributed: + train_samplers = [ + DistributedSampler( + train_dataset, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + ) + for train_dataset in train_datasets + ] + val_samplers = [ + DistributedSampler( + val_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + for val_dataset in val_datasets + ] + else: + train_samplers = [None] * len(train_datasets) + val_samplers = [None] * len(val_datasets) + + # Extract additive models and scaler and move them to CPU/float64 so they + # can be used in the collate function + model.additive_models[0].weights_to(device="cpu", dtype=torch.float64) + additive_models = copy.deepcopy( + model.additive_models.to(dtype=torch.float64, device="cpu") + ) + model.additive_models.to(device) + model.additive_models[0].weights_to(device=device, dtype=torch.float64) + model.scaler.scales_to(device="cpu", dtype=torch.float64) + scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu")) + model.scaler.to(device) + model.scaler.scales_to(device=device, dtype=torch.float64) + + # Create collate functions: + dataset_info = model.dataset_info + train_targets = dataset_info.targets + extra_data_info = dataset_info.extra_data + rotational_augmenter = RotationalAugmenter( + target_info_dict=train_targets, extra_data_info_dict=extra_data_info + ) + requested_neighbor_lists = get_requested_neighbor_lists(model) + collate_fn_train = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ + rotational_augmenter.apply_random_augmentations, + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + collate_fn_val = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ # no augmentation for validation + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + + # Create dataloader for the training datasets: + if self.hypers["num_workers"] is None: + num_workers = get_num_workers() + logging.info( + "Number of workers for data-loading not provided and chosen " + f"automatically. Using {num_workers} workers." + ) + else: + num_workers = self.hypers["num_workers"] + validate_num_workers(num_workers) + + train_dataloaders = [] + for train_dataset, train_sampler in zip( + train_datasets, train_samplers, strict=True + ): + if len(train_dataset) < self.hypers["batch_size"]: + raise ValueError( + f"A training dataset has fewer samples " + f"({len(train_dataset)}) than the batch size " + f"({self.hypers['batch_size']}). " + "Please reduce the batch size." 
+ ) + train_dataloaders.append( + DataLoader( + dataset=train_dataset, + batch_size=self.hypers["batch_size"], + sampler=train_sampler, + shuffle=( + # the sampler takes care of this (if present) + train_sampler is None + ), + drop_last=( + # the sampler takes care of this (if present) + train_sampler is None + ), + collate_fn=collate_fn_train, + num_workers=num_workers, + ) + ) + train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True) + + # Create dataloader for the validation datasets: + val_dataloaders = [] + for val_dataset, val_sampler in zip(val_datasets, val_samplers, strict=True): + val_dataloaders.append( + DataLoader( + dataset=val_dataset, + batch_size=self.hypers["batch_size"], + sampler=val_sampler, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_val, + num_workers=num_workers, + ) + ) + val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False) + + if is_distributed: + model = DistributedDataParallel(model, device_ids=[device]) + + outputs_list = [] + for target_name, target_info in train_targets.items(): + outputs_list.append(target_name) + for gradient_name in target_info.gradients: + outputs_list.append(f"{target_name}_{gradient_name}_gradients") + + # Create a loss function: + loss_hypers = cast(Dict[str, LossSpecification], self.hypers["loss"]) # mypy + loss_fn = LossAggregator(targets=train_targets, config=loss_hypers) + logging.info("Using the following loss functions:") + for name, info in loss_fn.metadata.items(): + logging.info(f"{name}:") + main = {k: v for k, v in info.items() if k != "gradients"} + logging.info(main) + if "gradients" not in info or len(info["gradients"]) == 0: + continue + logging.info("With gradients:") + for grad, ginfo in info["gradients"].items(): + logging.info(f"\t{name}::{grad}: {ginfo}") + + if self.hypers["weight_decay"] is not None: + optimizer = torch.optim.AdamW( + model.parameters(), + lr=self.hypers["learning_rate"], + weight_decay=self.hypers["weight_decay"], + ) + else: + optimizer = torch.optim.Adam( + model.parameters(), lr=self.hypers["learning_rate"] + ) + + if self.optimizer_state_dict is not None and not is_finetune: + # try to load the optimizer state dict, but this is only possible + # if there are no new targets in the model (new parameters) + if not (model.module if is_distributed else model).has_new_targets: + optimizer.load_state_dict(self.optimizer_state_dict) + + # Create a learning rate scheduler + lr_scheduler = get_scheduler(optimizer, self.hypers, len(train_dataloader)) + + if self.scheduler_state_dict is not None and not is_finetune: + # same as the optimizer, try to load the scheduler state dict + if not (model.module if is_distributed else model).has_new_targets: + lr_scheduler.load_state_dict(self.scheduler_state_dict) + + per_structure_targets = self.hypers["per_structure_targets"] + + # Log the initial learning rate: + logging.info(f"Base learning rate: {self.hypers['learning_rate']}") + + start_epoch = -1 if self.epoch is None else self.epoch + epoch = start_epoch + + # Save the untrained model checkpoint: + if self.best_metric is None: + self.best_metric = float("inf") + logging.info("Saving untrained model checkpoint") + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = 
lr_scheduler.state_dict() + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_untrained.ckpt", + ) + + # Train the model: + logging.info("Starting training") + start_epoch = start_epoch + 1 + + for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]): + if is_distributed: + for train_sampler in train_samplers: + train_sampler.set_epoch(epoch) + train_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + val_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + if self.hypers["log_mae"]: + train_mae_calculator = MAEAccumulator( + self.hypers["log_separate_blocks"] + ) + val_mae_calculator = MAEAccumulator(self.hypers["log_separate_blocks"]) + + train_loss = 0.0 + for batch in train_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + optimizer.zero_grad() + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=True, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + train_loss_batch = loss_fn(predictions, targets, extra_data) + + if is_distributed: + # make sure all parameters contribute to the gradient calculation + # to make torch DDP happy + for param in model.parameters(): + train_loss_batch += 0.0 * param.sum() + + train_loss_batch.backward() + torch.nn.utils.clip_grad_norm_( + model.parameters(), self.hypers["grad_clip_norm"] + ) + optimizer.step() + lr_scheduler.step() + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(train_loss_batch) + train_loss += train_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + train_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + train_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_train_info = train_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_train_info.update( + train_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + val_loss = 0.0 + for batch in val_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=False, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + val_loss_batch = loss_fn(predictions, targets, extra_data) + + if 
is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(val_loss_batch) + val_loss += val_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + val_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + val_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_val_info = val_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_val_info.update( + val_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + # Now we log the information: + finalized_train_info = { + "loss": train_loss, + **finalized_train_info, + } + finalized_val_info = { + "loss": val_loss, + **finalized_val_info, + } + + if epoch == start_epoch: + metric_logger = MetricLogger( + log_obj=ROOT_LOGGER, + dataset_info=( + model.module if is_distributed else model + ).dataset_info, + initial_metrics=[finalized_train_info, finalized_val_info], + names=["training", "validation"], + ) + if epoch % self.hypers["log_interval"] == 0: + metric_logger.log( + metrics=[finalized_train_info, finalized_val_info], + epoch=epoch, + rank=rank, + learning_rate=optimizer.param_groups[0]["lr"], + ) + + val_metric = get_selected_metric( + finalized_val_info, self.hypers["best_model_metric"] + ) + if val_metric < self.best_metric: + self.best_metric = val_metric + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + if epoch % self.hypers["checkpoint_interval"] == 0: + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + self.epoch = epoch + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_{epoch}.ckpt", + ) + + # prepare for the checkpoint that will be saved outside the function + self.epoch = epoch + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + + if is_distributed: + torch.distributed.destroy_process_group() + + def save_checkpoint(self, model: ModelInterface, path: Union[str, Path]) -> None: + checkpoint = model.get_checkpoint() + if self.best_model_state_dict is not None: + self.best_model_state_dict["finetune_config"] = model.finetune_config + checkpoint.update( + { + "trainer_ckpt_version": self.__checkpoint_version__, + "train_hypers": self.hypers, + "epoch": self.epoch, + "optimizer_state_dict": self.optimizer_state_dict, + "scheduler_state_dict": self.scheduler_state_dict, + "best_epoch": self.best_epoch, + "best_metric": self.best_metric, + "best_model_state_dict": self.best_model_state_dict, + "best_optimizer_state_dict": self.best_optimizer_state_dict, + } + ) + torch.save( + checkpoint, + check_file_extension(path, ".ckpt"), + ) + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + hypers: TrainerHypers, + context: Literal["restart", "finetune"], + ) -> "Trainer": + trainer = cls(hypers) + trainer.optimizer_state_dict = 
checkpoint["optimizer_state_dict"] + trainer.scheduler_state_dict = checkpoint["scheduler_state_dict"] + trainer.epoch = checkpoint["epoch"] + trainer.best_epoch = checkpoint["best_epoch"] + trainer.best_metric = checkpoint["best_metric"] + trainer.best_model_state_dict = checkpoint["best_model_state_dict"] + trainer.best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"] + + return trainer + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["trainer_ckpt_version"] == v: + update = getattr(checkpoints, f"trainer_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["trainer_ckpt_version"] = v + 1 + + if checkpoint["trainer_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using " + f"trainer version {checkpoint['trainer_ckpt_version']}, while the " + f"current trainer version is {cls.__checkpoint_version__}." + ) + return checkpoint diff --git a/src/metatrain/pet_10/__init__.py b/src/metatrain/pet_10/__init__.py new file mode 100644 index 0000000000..01db7d8aa9 --- /dev/null +++ b/src/metatrain/pet_10/__init__.py @@ -0,0 +1,21 @@ +from .model import PET +from .trainer import Trainer + + +__model__ = PET +__trainer__ = Trainer +__capabilities__ = { + "supported_devices": __model__.__supported_devices__, + "supported_dtypes": __model__.__supported_dtypes__, +} + +__authors__ = [ + ("Sergey Pozdnyakov ", "@spozdn"), + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] + +__maintainers__ = [ + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] diff --git a/src/metatrain/pet_10/checkpoints.py b/src/metatrain/pet_10/checkpoints.py new file mode 100644 index 0000000000..020b1e6590 --- /dev/null +++ b/src/metatrain/pet_10/checkpoints.py @@ -0,0 +1,422 @@ +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + + +########################### +# MODEL ################### +########################### + + +def model_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "additive_models.0.model.type_to_index" not in state_dict: + state_dict["additive_models.0.model.type_to_index"] = state_dict.pop( + "additive_models.0.type_to_index" + ) + + +def model_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "train_hypers" in checkpoint: + finetune_config = checkpoint["train_hypers"].get("finetune", {}) + else: + finetune_config = {} + state_dict["finetune_config"] = finetune_config + + +def model_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["epoch"] = checkpoint.get("epoch") + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + if checkpoint["best_model_state_dict"] is not None: + checkpoint["best_model_state_dict"] = checkpoint.get("best_model_state_dict") + + +def model_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 checkpoint to v5. + + :param checkpoint: The checkpoint to update. 
+ """ + checkpoint["model_data"]["dataset_info"]._atomic_types = list( + checkpoint["model_data"]["dataset_info"]._atomic_types + ) + + +def model_update_v5_v6(checkpoint: dict) -> None: + """ + Update a v5 checkpoint to v6. + + :param checkpoint: The checkpoint to update. + """ + if not checkpoint["best_model_state_dict"]: + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + +def model_update_v6_v7(checkpoint: dict) -> None: + """ + Update model checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to be updated. + """ + # this update consists in changes in the scaler + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if ( + "scaler.scales" not in state_dict + and "scaler.dummy_buffer" in state_dict + and "scaler.model.type_to_index" in state_dict + ): + continue # already updated + old_scales_tensor = state_dict.pop("scaler.scales") + old_output_name_to_output_index = {} + for target_index, target_name in enumerate( + checkpoint["model_data"]["dataset_info"].targets.keys() + ): + old_output_name_to_output_index[target_name] = target_index + state_dict["scaler.dummy_buffer"] = torch.tensor( + [0.0], dtype=old_scales_tensor.dtype + ) + state_dict["scaler.model.type_to_index"] = state_dict[ + "additive_models.0.model.type_to_index" + ] + for target_name, target_info in checkpoint["model_data"][ + "dataset_info" + ].targets.items(): + layout = target_info.layout + if layout.sample_names == ["system"]: + samples = Labels(["atomic_type"], torch.tensor([[-1]])) + + elif layout.sample_names == ["system", "atom"]: + samples = Labels( + ["atomic_type"], + torch.arange( + len(checkpoint["model_data"]["dataset_info"].atomic_types) + ).reshape(-1, 1), + ) + else: + raise ValueError( # will never happen + "Unknown sample kind. Please contact the developers." + ) + scales_tensormap = TensorMap( + keys=layout.keys, + blocks=[ + TensorBlock( + values=torch.full( # important when scale_targets=False + (len(samples), len(block.properties)), + old_scales_tensor[ + old_output_name_to_output_index[target_name] + ], + dtype=torch.float64, + ), + samples=samples, + components=[], + properties=block.properties, + ) + for block in layout.blocks() + ], + ) + state_dict[f"scaler.{target_name}_scaler_buffer"] = mts.save_buffer( + mts.make_contiguous(scales_tensormap) + ) + + +def model_update_v7_v8(checkpoint: dict) -> None: + """ + Update a v7 checkpoint to v8. + + :param checkpoint: The checkpoint to update. + """ + if "node_embedding.weight" in checkpoint["model_state_dict"]: + ############################################## + # **Updating the large-scale PET checkpoints** + ############################################## + + checkpoint["model_data"]["model_hypers"]["normalization"] = "RMSNorm" + checkpoint["model_data"]["model_hypers"]["activation"] = "SwiGLU" + checkpoint["model_data"]["model_hypers"]["d_node"] = ( + 4 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PreLN" + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "feedforward" + checkpoint["model_data"]["model_hypers"]["d_feedforward"] = ( + 2 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + if (state_dict := checkpoint.get("model_state_dict")) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + if "embedding." 
in k and "node" not in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if "node_embedding." in k: + k = k.replace("node_embedding.", "node_embedders.0.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + if "combination_rmsnorms" in k: + k = k.replace("combination_rmsnorms", "combination_norms") + new_state_dict[k] = v + checkpoint["model_state_dict"] = new_state_dict + + # for the large-scale checkpoints, the model for evaluation is always + # taken to be the last + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + else: + ############################################### + # **Updating the standard old PET checkpoints** + ############################################### + + # Adding the option for choosing the normalization type + if "normalization" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["normalization"] = "LayerNorm" + # Adding the option for choosing the activation function + if "activation" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["activation"] = "SiLU" + # Setting the node features dimension to be the same as d_pet if not specified + if "d_node" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["d_node"] = checkpoint[ + "model_data" + ]["model_hypers"]["d_pet"] + # Setting the default transformer type to PostLN if not specified + if "transformer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PostLN" + if "featurizer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "residual" + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + # Replacing the nn.Sequential MLP with a custom FeedForward module + if ".mlp.0" in k: + k = k.replace(".mlp.0", ".mlp.w_in") + if ".mlp.3" in k: + k = k.replace(".mlp.3", ".mlp.w_out") + # Moving the node embedder to a top-level PET model attribute + if "embedding." in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".node_embedder." in k: + key_content = k.split(".") + k = ".".join( + ["node_embedders", key_content[1], key_content[-1]] + ) + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + new_state_dict[k] = v + checkpoint[key] = new_state_dict + + +def model_update_v8_v9(checkpoint: dict) -> None: + """ + Update a v8 checkpoint to v9. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "finetune_config" in state_dict: + if "inherit_heads" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["inherit_heads"] = {} + if "method" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["method"] = "full" + + +def model_update_v9_v10(checkpoint: dict) -> None: + """ + Update a v9 checkpoint to v10. + + :param checkpoint: The checkpoint to update. 
+ """ + # Adding the num_neighbors_adaptive hyperparameter if not present + if "num_neighbors_adaptive" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["num_neighbors_adaptive"] = None + if "cutoff_function" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["cutoff_function"] = "Cosine" + + +def model_update_v10_v11(checkpoint: dict) -> None: + """ + Update a v10 checkpoint to v11. + + :param checkpoint: The checkpoint to update. + """ + # Adding the attention_temperature hyperparameter if not present + if "attention_temperature" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["attention_temperature"] = 1.0 + + +########################### +# TRAINER ################# +########################### + + +def trainer_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 Trainer checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["scheduler_factor"] = checkpoint["train_hypers"].get( + "scheduler_factor", 0.5 + ) + + +def trainer_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 Trainer checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + +def trainer_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 Trainer checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + new_loss_hypers[target_name] = { + "type": old_loss_hypers["type"], + "weight": old_loss_hypers["weights"].get(target_name, 1.0), + "reduction": old_loss_hypers["reduction"], + "sliding_factor": old_loss_hypers.get("sliding_factor", None), + } + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 Trainer checkpoint to v5. + + :param checkpoint: The checkpoint to update. + """ + raise ValueError( + "In order to use this checkpoint, you need metatrain 2025.10 or earlier. " + "You can install it with `pip install metatrain==2025.10`." + ) + + +def trainer_update_v5_v6(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 5 to version 6. + + :param checkpoint: The checkpoint to update. + """ + # num_workers=0 means that the main process will do the data loading, which is + # equivalent to not setting it (this was the behavior before v6) + checkpoint["train_hypers"]["num_workers"] = 0 + + +def trainer_update_v6_v7(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["fixed_scaling_weights"] = {} + + +def trainer_update_v7_v8(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. 
+ """ + # remove all entries in the loss `sliding_factor` + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + # retain everything except sliding_factor for each target + new_loss_hypers[target_name] = { + k: v + for k, v in old_loss_hypers[target_name].items() + if k != "sliding_factor" + } + + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v8_v9(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. + """ + # Adding the empty finetune config if not present + if "finetune" not in checkpoint["train_hypers"]: + checkpoint["train_hypers"]["finetune"] = { + "read_from": None, + "method": "full", + "config": {}, + "inherit_heads": {}, + } + + +def trainer_update_v9_v10(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 9 to version 10. + + :param checkpoint: The checkpoint to update. + """ + # Ensuring that the finetune read_from is None if not specified + checkpoint["train_hypers"]["remove_composition_contribution"] = True + + +def trainer_update_v10_v11(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 10 to version 11. + + :param checkpoint: The checkpoint to update. + """ + # - Remove the ``remove_composition_contribution`` hyper. + # - Rename ``fixed_composition_weights`` to ``atomic_baseline``. + # - If ``remove_composition_contribution`` is False, set all atomic baselines + # to 0.0 for all targets. + use_atomic_baseline = checkpoint["train_hypers"].pop( + "remove_composition_contribution" + ) + atomic_baseline = checkpoint["train_hypers"].pop("fixed_composition_weights") + + if not use_atomic_baseline: + # Just set + dataset_info = checkpoint["model_data"]["dataset_info"] + atomic_baseline = {target_name: 0.0 for target_name in dataset_info.targets} + + checkpoint["train_hypers"]["atomic_baseline"] = atomic_baseline + + +def trainer_update_v11_v12(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 11 to version 12. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["batch_atom_bounds"] = [None, None] diff --git a/src/metatrain/pet_10/documentation.py b/src/metatrain/pet_10/documentation.py new file mode 100644 index 0000000000..a48c0abf45 --- /dev/null +++ b/src/metatrain/pet_10/documentation.py @@ -0,0 +1,251 @@ +""" +PET +=== + +PET is a cleaner, more user-friendly reimplementation of the original +PET model :footcite:p:`pozdnyakov_smooth_2023`. It is designed for better +modularity and maintainability, while preseving compatibility with the original +PET implementation in ``metatrain``. It also adds new features like long-range +features, better fine-tuning implementation, a possibility to train on +arbitrarty targets, and a faster inference due to the ``fast attention``. + +{{SECTION_INSTALLATION}} + +{{SECTION_DEFAULT_HYPERS}} + +Tuning hyperparameters +---------------------- + +The default hyperparameters above will work well in most cases, but they +may not be optimal for your specific dataset. There is good number of +parameters to tune, both for the +:ref:`model ` and the +:ref:`trainer `. Since seeing them +for the first time might be overwhelming, here we provide a **list of the +parameters that are in general the most important** (in decreasing order +of importance): + +.. container:: mtt-hypers-remove-classname + + .. 
+"""
+
+from typing import Literal, Optional
+
+from typing_extensions import TypedDict
+
+from metatrain.pet_10.modules.finetuning import FinetuneHypers, NoFinetuneHypers
+from metatrain.utils.additive import FixedCompositionWeights
+from metatrain.utils.hypers import init_with_defaults
+from metatrain.utils.long_range import LongRangeHypers
+from metatrain.utils.loss import LossSpecification
+from metatrain.utils.scaler import FixedScalerWeights
+
+
+class ModelHypers(TypedDict):
+    """Hyperparameters for the PET model."""
+
+    cutoff: float = 4.5
+    """Cutoff radius for neighbor search.
+
+    This should be set to a value after which most of the interactions
+    between atoms are expected to be negligible. A lower cutoff will lead
+    to faster models.
+    """
+    num_neighbors_adaptive: Optional[int] = None
+    """Target number of neighbors for the adaptive cutoff scheme.
+
+    This parameter activates the adaptive cutoff functionality.
+    Each atomic environment has a different cutoff, chosen
+    such that the number of neighbors is approximately equal to this
+    value. This can be useful to obtain a more uniform number of neighbors
+    per atom, especially in sparse systems. Setting it to None disables
+    this feature and uses all neighbors within the fixed cutoff radius.
+    """
+    cutoff_function: Literal["Cosine", "Bump"] = "Bump"
+    """Type of the smoothing function at the cutoff"""
+    cutoff_width: float = 0.5
+    """Width of the smoothing function at the cutoff"""
+    d_pet: int = 128
+    """Dimension of the edge features.
+
+    This hyperparameter controls the width of the neural network. In general,
+    increasing it might lead to better accuracy, especially on larger datasets, at the
+    cost of increased training and evaluation time.
+    """
+    d_head: int = 128
+    """Dimension of the attention heads."""
+    d_node: int = 256
+    """Dimension of the node features.
+
+    Increasing this hyperparameter might lead to better accuracy,
+    with a relatively small increase in inference time.
+    """
+    d_feedforward: int = 256
+    """Dimension of the feedforward network in the attention layer."""
+    num_heads: int = 8
+    """Attention heads per attention layer."""
+    num_attention_layers: int = 2
+    """The number of attention layers in each layer of the graph
+    neural network. Depending on the dataset, increasing this hyperparameter might
+    lead to better accuracy, at the cost of increased training and evaluation time.
+    """
+    num_gnn_layers: int = 2
+    """The number of graph neural network layers.
+
+    In general, decreasing this hyperparameter to 1 will lead to much faster models,
+    at the expense of accuracy. Increasing it may or may not lead to better accuracy,
+    depending on the dataset, at the cost of increased training and evaluation time.
+ """ + normalization: Literal["RMSNorm", "LayerNorm"] = "RMSNorm" + """Layer normalization type.""" + activation: Literal["SiLU", "SwiGLU"] = "SwiGLU" + """Activation function.""" + attention_temperature: float = 1.0 + """The temperature scaling factor for attention scores.""" + transformer_type: Literal["PreLN", "PostLN"] = "PreLN" + """The order in which the layer normalization and attention + are applied in a transformer block. Available options are ``PreLN`` + (normalization before attention) and ``PostLN`` (normalization after attention).""" + featurizer_type: Literal["residual", "feedforward"] = "feedforward" + """Implementation of the featurizer of the model to use. Available + options are ``residual`` (the original featurizer from the PET paper, that uses + residual connections at each GNN layer for readout) and ``feedforward`` (a modern + version that uses the last representation after all GNN iterations for readout). + Additionally, the feedforward version uses bidirectional features flow during the + message passing iterations, that favors features flowing from atom ``i`` to atom + ``j`` to be not equal to the features flowing from atom ``j`` to atom ``i``.""" + zbl: bool = False + """Use ZBL potential for short-range repulsion""" + long_range: LongRangeHypers = init_with_defaults(LongRangeHypers) + """Long-range Coulomb interactions parameters.""" + + +class TrainerHypers(TypedDict): + """Hyperparameters for training PET models.""" + + distributed: bool = False + """Whether to use distributed training""" + distributed_port: int = 39591 + """Port for distributed communication among processes""" + batch_size: int = 16 + """The number of samples to use in each batch of training. This + hyperparameter controls the tradeoff between training speed and memory usage. In + general, larger batch sizes will lead to faster training, but might require more + memory.""" + num_epochs: int = 1000 + """Number of epochs.""" + warmup_fraction: float = 0.01 + """Fraction of training steps used for learning rate warmup.""" + learning_rate: float = 1e-4 + """Learning rate.""" + weight_decay: Optional[float] = None + + log_interval: int = 1 + """Interval to log metrics.""" + checkpoint_interval: int = 100 + """Interval to save checkpoints.""" + atomic_baseline: FixedCompositionWeights = {} + """The baselines for each target. + + By default, ``metatrain`` will fit a linear model (:class:`CompositionModel + `) to compute the + least squares baseline for each atomic species for each target. + + However, this hyperparameter allows you to provide your own baselines. + The value of the hyperparameter should be a dictionary where the keys are the + target names, and the values are either (1) a single baseline to be used for + all atomic types, or (2) a dictionary mapping atomic types to their baselines. + For example: + + - ``atomic_baseline: {"energy": {1: -0.5, 6: -10.0}}`` will fix the energy + baseline for hydrogen (Z=1) to -0.5 and for carbon (Z=6) to -10.0, while + fitting the baselines for the energy of all other atomic types, as well + as fitting the baselines for all other targets. + - ``atomic_baseline: {"energy": -5.0}`` will fix the energy baseline for + all atomic types to -5.0. + - ``atomic_baseline: {"mtt:dos": 0.0}`` sets the baseline for the "mtt:dos" + target to 0.0, effectively disabling the atomic baseline for that target. 
+
+    This atomic baseline is subtracted from the targets during training, which
+    spares the main model from having to learn atomic contributions and typically
+    makes training easier. When the model is used in evaluation mode, the atomic
+    baseline is added on top of the model predictions automatically.
+
+    .. note::
+
+        This atomic baseline is a per-atom contribution. Therefore, if the property
+        you are predicting is a sum over all atoms (e.g., total energy), the
+        contribution of the atomic baseline to the total property will be the
+        atomic baseline multiplied by the number of atoms of that type in the
+        structure.
+    """
+    scale_targets: bool = True
+    """Normalize targets to unit std during training."""
+    fixed_scaling_weights: FixedScalerWeights = {}
+    """Weights for target scaling.
+
+    This is passed to the ``fixed_weights`` argument of
+    :meth:`Scaler.train_model `,
+    see its documentation to understand exactly what to pass here.
+    """
+    per_structure_targets: list[str] = []
+    """Targets to calculate per-structure losses."""
+    num_workers: Optional[int] = None
+    """Number of workers for data loading. If not provided, it is set
+    automatically."""
+    log_mae: bool = True
+    """Log MAE alongside RMSE"""
+    log_separate_blocks: bool = False
+    """Log per-block error."""
+    best_model_metric: Literal["rmse_prod", "mae_prod", "loss"] = "mae_prod"
+    """Metric used to select the best checkpoint (e.g., ``rmse_prod``)"""
+    grad_clip_norm: float = 1.0
+    """Maximum gradient norm value."""
+    loss: str | dict[str, LossSpecification | str] = "mse"
+    """This section describes the loss function to be used. See the
+    :ref:`loss-functions` for more details."""
+    batch_atom_bounds: list[Optional[int]] = [None, None]
+    """Bounds for the number of atoms per batch as [min, max]. Batches with atom
+    counts outside these bounds will be skipped during training. Use ``None`` for
+    either value to disable that bound. This is useful for preventing out-of-memory
+    errors and ensuring consistent computational load. Default: ``[None, None]``."""
+
+    finetune: NoFinetuneHypers | FinetuneHypers = {
+        "read_from": None,
+        "method": "full",
+        "config": {},
+        "inherit_heads": {},
+    }
+    """Parameters for fine-tuning trained PET models.
+
+    See :ref:`label_fine_tuning_concept` for more details.
+    """
diff --git a/src/metatrain/pet_10/model.py b/src/metatrain/pet_10/model.py
new file mode 100644
index 0000000000..4a3aea9e3c
--- /dev/null
+++ b/src/metatrain/pet_10/model.py
@@ -0,0 +1,1646 @@
+import logging
+import typing
+import warnings
+from math import prod
+from typing import Any, Dict, List, Literal, Optional, Tuple
+
+import metatensor.torch as mts
+import torch
+from metatensor.torch import Labels, TensorBlock, TensorMap
+from metatensor.torch.operations._add import _add_block_block
+from metatomic.torch import (
+    AtomisticModel,
+    ModelCapabilities,
+    ModelMetadata,
+    ModelOutput,
+    NeighborListOptions,
+    System,
+)
+from torch.utils.hooks import RemovableHandle
+
+from metatrain.utils.abc import ModelInterface
+from metatrain.utils.additive import ZBL, CompositionModel
+from metatrain.utils.data import DatasetInfo, TargetInfo
+from metatrain.utils.dtype import dtype_to_str
+from metatrain.utils.long_range import DummyLongRangeFeaturizer, LongRangeFeaturizer
+from metatrain.utils.metadata import merge_metadata
+from metatrain.utils.scaler import Scaler
+from metatrain.utils.sum_over_atoms import sum_over_atoms
+
+from . 
import checkpoints +from .documentation import ModelHypers +from .modules.finetuning import apply_finetuning_strategy +from .modules.structures import get_pair_sample_labels, systems_to_batch +from .modules.transformer import CartesianTransformer + + +AVAILABLE_FEATURIZERS = typing.get_args(ModelHypers.__annotations__["featurizer_type"]) + + +class PET(ModelInterface[ModelHypers]): + """ + Metatrain-native implementation of the PET architecture. + + Originally proposed in work (https://arxiv.org/abs/2305.19302v3), + and published in the `pet` package (https://github.com/spozdn/pet). + + :param hypers: Hyperparameters for the PET model. See the documentation for details. + :param dataset_info: Information about the dataset, including atomic types and + targets. + """ + + __checkpoint_version__ = 11 + __supported_devices__ = ["cuda", "cpu"] + __supported_dtypes__ = [torch.float32, torch.float64] + __default_metadata__ = ModelMetadata( + references={"architecture": ["https://arxiv.org/abs/2305.19302v3"]} + ) + component_labels: Dict[str, List[List[Labels]]] + NUM_FEATURE_TYPES: int = 2 # node + edge features + + def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None: + super().__init__(hypers, dataset_info, self.__default_metadata__) + + # Cache frequently accessed hyperparameters + self.cutoff = float(self.hypers["cutoff"]) + self.cutoff_function = self.hypers["cutoff_function"] + self.cutoff_width = float(self.hypers["cutoff_width"]) + self.num_neighbors_adaptive = ( + float(self.hypers["num_neighbors_adaptive"]) + if self.hypers["num_neighbors_adaptive"] is not None + else None + ) + self.d_pet = self.hypers["d_pet"] + self.d_node = self.hypers["d_node"] + self.d_head = self.hypers["d_head"] + self.d_feedforward = self.hypers["d_feedforward"] + self.num_heads = self.hypers["num_heads"] + self.num_gnn_layers = self.hypers["num_gnn_layers"] + self.num_attention_layers = self.hypers["num_attention_layers"] + self.normalization = self.hypers["normalization"] + self.activation = self.hypers["activation"] + self.attention_temperature = self.hypers["attention_temperature"] + self.transformer_type = self.hypers["transformer_type"] + self.featurizer_type = self.hypers["featurizer_type"] + + self.atomic_types = dataset_info.atomic_types + self.requested_nl = NeighborListOptions( + cutoff=self.cutoff, + full_list=True, + strict=True, + ) + num_atomic_species = len(self.atomic_types) + self.gnn_layers = torch.nn.ModuleList( + [ + CartesianTransformer( + self.cutoff, + self.cutoff_width, + self.d_pet, + self.num_heads, + self.d_node, + self.d_feedforward, + self.num_attention_layers, + self.normalization, + self.activation, + self.attention_temperature, + self.transformer_type, + num_atomic_species, + layer_index == 0, # is first layer + ) + for layer_index in range(self.num_gnn_layers) + ] + ) + if self.featurizer_type not in AVAILABLE_FEATURIZERS: + raise ValueError( + f"Unknown featurizer type: {self.featurizer_type}. 
" + f"Available options are: {AVAILABLE_FEATURIZERS}" + ) + if self.featurizer_type == "feedforward": + self.num_readout_layers = 1 + self.combination_norms = torch.nn.ModuleList( + [torch.nn.LayerNorm(2 * self.d_pet, bias=False) for _ in range(self.num_gnn_layers)] + ) + self.combination_mlps = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(2 * self.d_pet, 2 * self.d_pet, bias=False), + torch.nn.SiLU(), + torch.nn.Linear(2 * self.d_pet, self.d_pet, bias=False), + ) + for _ in range(self.num_gnn_layers) + ] + ) + else: + self.num_readout_layers = self.num_gnn_layers + self.combination_norms = torch.nn.ModuleList() + self.combination_mlps = torch.nn.ModuleList() + + self.node_embedders = torch.nn.ModuleList( + [ + torch.nn.Embedding(num_atomic_species, self.d_node) + for _ in range(self.num_readout_layers) + ] + ) + self.edge_embedder = torch.nn.Embedding(num_atomic_species, self.d_pet) + + self.node_heads = torch.nn.ModuleDict() + self.edge_heads = torch.nn.ModuleDict() + self.node_last_layers = torch.nn.ModuleDict() + self.edge_last_layers = torch.nn.ModuleDict() + self.last_layer_feature_size = ( + self.num_readout_layers * self.d_head * self.NUM_FEATURE_TYPES + ) # for LLPR + + # the model is always capable of outputting the internal features + self.outputs = { + "features": ModelOutput(per_atom=True, description="internal features") + } + + self.output_shapes: Dict[str, Dict[str, List[int]]] = {} + self.key_labels: Dict[str, Labels] = {} + self.property_labels: Dict[str, List[Labels]] = {} + self.component_labels: Dict[str, List[List[Labels]]] = {} + self.target_names: List[str] = [] + self.last_layer_parameter_names: Dict[str, List[str]] = {} # for LLPR + for target_name, target_info in dataset_info.targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target_info) + + self.register_buffer( + "species_to_species_index", + torch.full((max(self.atomic_types) + 1,), -1), + ) + for i, species in enumerate(self.atomic_types): + self.species_to_species_index[species] = i + + # long-range module + if self.hypers["long_range"]["enable"]: + self.long_range = True + if not self.hypers["long_range"]["use_ewald"]: + warnings.warn( + "Training PET with the LongRangeFeaturizer initialized " + "with `use_ewald=False` causes instabilities during training. " + "The `use_ewald` variable will be force-switched to `True`. 
" + "during training.", + UserWarning, + stacklevel=2, + ) + self.long_range_featurizer = LongRangeFeaturizer( + hypers=self.hypers["long_range"], + feature_dim=self.d_node, + neighbor_list_options=self.requested_nl, + ) + else: + self.long_range = False + self.long_range_featurizer = DummyLongRangeFeaturizer() # for torchscript + + # additive models: these are handled by the trainer at training + # time, and they are added to the output at evaluation time + composition_model = CompositionModel( + hypers={}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + additive_models = [composition_model] + + # Adds the ZBL repulsion model if requested + if self.hypers["zbl"]: + additive_models.append( + ZBL( + {}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if ZBL.is_valid_target(target_name, target_info) + }, + ), + ) + ) + self.additive_models = torch.nn.ModuleList(additive_models) + + # scaler: this is also handled by the trainer at training time + self.scaler = Scaler(hypers={}, dataset_info=dataset_info) + + self.single_label = Labels.single() + + self.finetune_config: Dict[str, Any] = {} + + def supported_outputs(self) -> Dict[str, ModelOutput]: + return self.outputs + + def restart(self, dataset_info: DatasetInfo) -> "PET": + # merge old and new dataset info + merged_info = self.dataset_info.union(dataset_info) + new_atomic_types = [ + at for at in merged_info.atomic_types if at not in self.atomic_types + ] + new_targets = { + key: value + for key, value in merged_info.targets.items() + if key not in self.dataset_info.targets + } + self.has_new_targets = len(new_targets) > 0 + + if len(new_atomic_types) > 0: + raise ValueError( + f"New atomic types found in the dataset: {new_atomic_types}. " + "The PET model does not support adding new atomic types." + ) + + # register new outputs as new last layers + for target_name, target in new_targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target) + + self.dataset_info = merged_info + + # restart the composition and scaler models + self.additive_models[0] = self.additive_models[0].restart( + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + self.scaler = self.scaler.restart(dataset_info) + + return self + + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return [self.requested_nl] + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + """ + Forward pass of the PET model. + + The forward pass processes atomic systems through multiple stages to produce + predictions for the requested outputs. The computation follows a graph neural + network architecture with attention-based message passing. 
+ + **Stage 0: Input Preparation** + + The input systems are first converted into a batched representation containing: + + - `element_indices_nodes` [n_atoms]: Atomic species of the central atoms + - `element_indices_neighbors` [n_atoms, max_num_neighbors]: Atomic species of + neighboring atoms + - `edge_vectors` [n_atoms, max_num_neighbors, 3]: Cartesian edge vectors + between central atoms and their neighbors + - `padding_mask` [n_atoms, max_num_neighbors]: Mask indicating real vs padded + neighbors + - `reverse_neighbor_index` [n_atoms * max_num_neighbors]: Index of the ji edge + for each ij edge, once the edges are flattened into an array whose first + dimension is n_atoms * max_num_neighbors + - `system_indices` [n_atoms]: System index for each central atom + - `sample_labels` [n_atoms, 2]: Metatensor Labels containing indices of each + atom in each system + + **Stage 1: Feature Computation via GNN Layers** + + Node and edge representations are computed by iterating through the GNN layers + following one of two featurization strategies: + + - **Feedforward featurization**: Features are propagated through all + `num_gnn_layers` GNN layers sequentially, using only the final layer outputs + for readout. At each layer, forward and reversed edge messages are combined + using combination MLPs to enable bidirectional information flow. + + - **Residual featurization**: Intermediate node and edge features from each + GNN layer are saved and used during readout. Edge messages between layers are + averaged to maintain information from all hops. + + During this stage, the model: + + - Embeds atomic species into learned node and edge representations + - Applies Cartesian transformer layers to update features via attention + - Uses reversed neighbor lists to enable bidirectional message passing, where + the new input message from atom `j` to atom `i` in GNN layer N+1 is the + reversed message from atom `i` to atom `j` in GNN layer N + - Applies cutoff functions to weight interactions by distance + + If the long-range module is enabled, electrostatic features computed via Ewald + summation (during training) or Particle-Particle Particle Mesh Ewald (P3M) + (during evaluation) are added to the node features from each GNN layer. + + **Stage 2: Intermediate Feature Output (Optional)** + + If "features" is requested in the outputs, node and edge features from all + layers are concatenated to produce intermediate representations. Edge features + are summed over neighbors with cutoff weighting to obtain per-node + contributions. This output can be used for transfer learning or analysis. + + **Stage 3: Last Layer Feature Computation** + + For each requested output, output-specific heads (shallow MLPs with two linear + layers and SiLU activations) are applied to both node and edge features from + each GNN layer. This produces last layer features that are specialized for each + prediction target. These features can be optionally returned as auxiliary + outputs (e.g., "mtt::aux::energy_last_layer_features") for analysis or + transfer learning. 
+ + **Stage 4: Atomic Predictions** + + Final linear layers are applied to the last layer features to produce per-atom + predictions for each requested output: + + - Node and edge last layer features are processed through separate linear + layers for each output block + - Contributions from all GNN layers are summed + - Edge contributions are summed over neighbors with cutoff weighting + - For rank-2 Cartesian tensors (e.g., stress), predictions are symmetrized and + normalized by cell volume + - Multiple tensor blocks per output are handled independently + + **Post-processing (Evaluation Only)** + + During evaluation (not training), the following transformations are applied: + + 1. **Scaling**: Predictions are scaled using learned or configured scale + factors + 2. **Additive contributions**: Composition model and optional ZBL repulsion + contributions are added to the predictions + + :param systems: List of `metatomic.torch.System` objects to process. Each + system should contain atomic positions, species, and cell information, with + neighbor lists computed according to `requested_neighbor_lists()`. + :param outputs: Dictionary of requested outputs in the format + {output_name: ModelOutput(...)}. The model supports: + + - Target properties (energy, forces, stress, etc.) + - "features": intermediate representations from Stage 2 + - Auxiliary last layer features (e.g., + "mtt::aux::energy_last_layer_features") + + :param selected_atoms: Optional `metatensor.torch.Labels` object specifying a + subset of atoms for which to compute outputs. If `None`, all atoms are + included. This is useful for computing properties for specific atomic + environments. + :return: Dictionary of `metatensor.torch.TensorMap` objects containing the + requested outputs. Each TensorMap contains per-atom or per-structure + predictions (depending on the ModelOutput configuration) with appropriate + metatensor metadata (samples, components, properties). 
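+
+        A minimal usage sketch (assuming ``systems`` already carry the neighbor
+        lists requested by :meth:`requested_neighbor_lists` and that the model
+        was trained on a target named ``"energy"``)::
+
+            from metatomic.torch import ModelOutput
+
+            outputs = {"energy": ModelOutput(per_atom=False)}
+            predictions = model(systems, outputs)
+            energies = predictions["energy"].block().values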
+ """ + device = systems[0].device + return_dict: Dict[str, TensorMap] = {} + nl_options = self.requested_neighbor_lists()[0] + + if self.single_label.values.device != device: + self._move_labels_to_device(device) + + # **Stage 0: Input Preparation** + ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + padding_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + sample_labels, + centers, + nef_to_edges_neighbor, + ) = systems_to_batch( + systems, + nl_options, + self.atomic_types, + self.species_to_species_index, + self.cutoff_function, + self.cutoff_width, + self.num_neighbors_adaptive, + selected_atoms, + ) + + pair_sample_labels = get_pair_sample_labels( + systems, sample_labels, nl_options, device + ) + + # Optional diagnostic token capture: register temporary module hooks + diagnostic_handles = torch.jit.annotate(List[Any], []) + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + diagnostic_handles = self._prepare_diagnostic_handles( + outputs, + return_dict, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + + # the scaled_dot_product_attention function from torch cannot do + # double backward, so we will use manual attention if needed + use_manual_attention = edge_vectors.requires_grad and self.training + + with torch.profiler.record_function("PET::_calculate_features"): + # **Stage 1: Feature Computation via GNN Layers** + featurizer_inputs: Dict[str, torch.Tensor] = dict( + element_indices_nodes=element_indices_nodes, + element_indices_neighbors=element_indices_neighbors, + edge_vectors=edge_vectors, + edge_distances=edge_distances, + reverse_neighbor_index=reverse_neighbor_index, + padding_mask=padding_mask, + cutoff_factors=cutoff_factors, + ) + for featurizer_input_name, tensor in featurizer_inputs.items(): + if "mtt::features::" + featurizer_input_name in outputs: + return_dict["mtt::features::" + featurizer_input_name] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + node_features_list, edge_features_list = self._calculate_features( + featurizer_inputs, + use_manual_attention=use_manual_attention, + ) + + # If the long-range module is activated, we add the long-range features + # on top of the node features + + if self.long_range: + long_range_features = self._calculate_long_range_features( + systems, node_features_list, edge_distances, padding_mask + ) + for i in range(self.num_readout_layers): + node_features_list[i] = ( + node_features_list[i] + long_range_features + ) * 0.5**0.5 + + # **Stage 2: Intermediate Feature Output (Optional)** + with torch.profiler.record_function("PET::_get_output_features"): + if "features" in outputs: + features_dict = self._get_output_features( + node_features_list, + edge_features_list, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + # Since return_dict.update(features_dict) is not Torch-Scriptable, + # we use a simple iteration over the features_dict items. 
+ for k, v in features_dict.items(): + return_dict[k] = v + + # **Stage 3: Last Layer Feature Computation** + with torch.profiler.record_function("PET::_calculate_last_layer_features"): + node_last_layer_features_dict, edge_last_layer_features_dict = ( + self._calculate_last_layer_features( + node_features_list, + edge_features_list, + ) + ) + last_layer_features_dict = self._get_output_last_layer_features( + node_last_layer_features_dict, + edge_last_layer_features_dict, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + + for k, v in last_layer_features_dict.items(): + return_dict[k] = v + + # **Stage 4: Atomic Predictions** + with torch.profiler.record_function("PET::_calculate_atomic_predictions"): + node_atomic_predictions_dict, edge_atomic_predictions_dict = ( + self._calculate_atomic_predictions( + node_last_layer_features_dict, + edge_last_layer_features_dict, + padding_mask, + cutoff_factors, + outputs, + ) + ) + atomic_predictions_dict = self._get_output_atomic_predictions( + systems, + node_atomic_predictions_dict, + edge_atomic_predictions_dict, + edge_vectors, + system_indices, + sample_labels, + outputs, + selected_atoms, + ) + + for k, v in atomic_predictions_dict.items(): + return_dict[k] = v + + # **Post-processing (Evaluation Only)** + + with torch.profiler.record_function("PET::post-processing"): + if not self.training: + # at evaluation, we also introduce the scaler and additive contributions + return_dict = self.scaler( + systems, return_dict, selected_atoms=selected_atoms + ) + for additive_model in self.additive_models: + outputs_for_additive_model: Dict[str, ModelOutput] = {} + for name, output in outputs.items(): + if name in additive_model.outputs: + outputs_for_additive_model[name] = output + additive_contributions = additive_model( + systems, + outputs_for_additive_model, + selected_atoms, + ) + for name in additive_contributions: + # TODO: uncomment this after metatensor.torch.add + # is updated to handle sparse sums + # return_dict[name] = metatensor.torch.add( + # return_dict[name], + # additive_contributions[name].to( + # device=return_dict[name].device, + # dtype=return_dict[name].dtype + # ), + # ) + # TODO: "manual" sparse sum: update to metatensor.torch.add + # after sparse sum is implemented in metatensor.operations + output_blocks: List[TensorBlock] = [] + for k, b in return_dict[name].items(): + if k in additive_contributions[name].keys: + output_blocks.append( + _add_block_block( + b, + additive_contributions[name] + .block(k) + .to(device=b.device, dtype=b.dtype), + ) + ) + else: + output_blocks.append(b) + return_dict[name] = TensorMap( + return_dict[name].keys, output_blocks + ) + + # remove any diagnostic hooks we registered and attach tokens + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + for h in diagnostic_handles: + h.remove() + + return return_dict + + def _create_diagnostic_feature_tensormap( + self, + tensor: torch.Tensor, + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> TensorMap: + assert tensor.shape[0] == sample_labels.values.shape[0], ( + "diagnostic feature tensor must be per-atom or per-pair like in shape." 
+ f" Got tensor.shape = {tensor.shape}, " + ) + + device = tensor.device + + outp = tensor.detach().clone() + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + if outp.shape[1] == 1: # node-like, shape (n_atoms, 1, d) + outp = outp.squeeze(1) + labels = sample_labels + + else: # edge-like, shape (n_atoms, num_neighbors, d) + outp = outp[centers, nef_to_edges_neighbor] + labels = pair_sample_labels + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + return TensorMap( + Labels(["_"], torch.tensor([[0]])).to(device=device), + [ + TensorBlock( + values=outp, + samples=labels.to(device=device), + components=[], + properties=Labels( + ["_"], + torch.arange(outp.shape[1]).reshape(-1, 1), + ).to(device=device), + ) + ], + ) + + @torch.jit.ignore + def _prepare_diagnostic_handles( + self, + outputs: Dict[str, ModelOutput], + return_dict: Dict[str, Any], + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> List[Any]: + """ + Prepare forward hooks to capture diagnostic tokens from internal modules. + + :param outputs: Dictionary of requested outputs. + :param return_dict: Dictionary to store captured tokens. + :param centers: Tensor mapping center atoms to their indices. + :param nef_to_edges_neighbor: Tensor mapping neighbor edges to their indices. + :param sample_labels: Labels for individual atoms. + :param pair_sample_labels: Labels for atom pairs. + + :return: List of removable handles for the registered hooks. + """ + diagnostic_handles = [] + + def _resolve_module(path: str) -> Any: + obj: Any = self + for part in path.split("."): + if part.isdigit(): + obj = obj[int(part)] + else: + if not hasattr(obj, part): + raise AttributeError( + f"Module path '{path}' not found at '{part}'" + ) + obj = getattr(obj, part) + return obj + + # Build list of possible module paths that can be captured for these model + # hypers. 
+ possible_capture_paths: List[str] = [] + + for i in range(self.num_readout_layers): + possible_capture_paths.append(f"node_embedders.{i}") + possible_capture_paths.append("edge_embedder") + + for i in range(self.num_gnn_layers): + # Total GNN layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}_node") + possible_capture_paths.append(f"gnn_layers.{i}_edge") + + # Finer-grained: embeddings and compressions + possible_capture_paths.append(f"gnn_layers.{i}.edge_embedder") + if i > 0: + possible_capture_paths.append(f"gnn_layers.{i}.neighbor_embedder") + possible_capture_paths.append(f"gnn_layers.{i}.compress") + + # transformer layers + for j in range(self.num_attention_layers): + # Transformer layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_node") + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_edge") + + # Finer-grained: attention and MLP submodules + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_contraction" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_mlp" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_mlp" + ) + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}.mlp") + + # TODO: this is subtle, depending on whether the featurizer is feedforward or + # residual. To fix later. + # possible_capture_paths.append(f"combination_norms") + # possible_capture_paths.append(f"combination_mlps") + + for path in possible_capture_paths: + if "mtt::features::" + path not in outputs: + continue + + if "_node" in path: + suffix = "_node" + path = path.replace(suffix, "") + elif "_edge" in path: + suffix = "_edge" + path = path.replace(suffix, "") + else: + suffix = "" + + module = _resolve_module(path) + + def make_hook(p: str, suffix: str) -> Any: + def _hook( + module: torch.nn.Module, inp: Any, outp: Any + ) -> None: + if isinstance(outp, tuple): + assert "_node" in suffix or "_edge" in suffix, ( + "When capturing from a module that returns multiple " + "outputs, the requested output must carry the suffix " + "'_node' or '_edge'." + ) + if suffix == "_node": + tensor = outp[0] + else: + assert suffix == "_edge" + tensor = outp[1] + + return_dict[f"mtt::features::{p}{suffix}"] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + else: + return_dict["mtt::features::" + p] = ( + self._create_diagnostic_feature_tensormap( + outp, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + + return _hook + + handle = module.register_forward_hook(make_hook(path, suffix)) + diagnostic_handles.append(handle) + + return diagnostic_handles + + def _calculate_features( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Calculate node and edge features using the selected featurization strategy. + Returns lists of feature tensors from GNN layers. 
+ + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors + - List of edge feature tensors + In the case of feedforward featurization, each list contains a single tensor + from the final GNN layer. In the case of residual featurization, each list + contains tensors from all GNN layers. + """ + if self.featurizer_type == "feedforward": + return self._feedforward_featurization_impl(inputs, use_manual_attention) + else: + return self._residual_featurization_impl(inputs, use_manual_attention) + + def _feedforward_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Feedforward featurization: iterates features through all GNN layers, + returning only the final layer outputs. Uses combination MLPs to mix + forward and reversed edge messages at each layer. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from the final GNN layer + - List of edge feature tensors from the final GNN layer + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + + input_node_embeddings = self.node_embedders[0](inputs["element_indices_nodes"]) + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for combination_norm, combination_mlp, gnn_layer in zip( + self.combination_norms, self.combination_mlps, self.gnn_layers, strict=True + ): + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + input_node_embeddings = output_node_embeddings + new_input_edge_embeddings = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + # input_messages = 0.5 * (output_edge_embeddings + new_input_messages) + concatenated = torch.cat( + [output_edge_embeddings, new_input_edge_embeddings], dim=-1 + ) + input_edge_embeddings = ( + input_edge_embeddings + + output_edge_embeddings + + combination_mlp(combination_norm(concatenated)) + ) + + node_features_list.append(input_node_embeddings) + edge_features_list.append(input_edge_embeddings) + return node_features_list, edge_features_list + + def _residual_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Residual featurization: saves intermediate features from each GNN layer + for use in readout. 
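The flatten-index-reshape step used in both featurization variants is easier to see on a tiny example. The sketch below builds a two-atom system in which each atom is the other's only neighbor, so reversing the edge messages simply swaps the two flattened slots; the `reverse_neighbor_index` values are written by hand for illustration, and the simple `0.5` average shown corresponds to the residual variant (the feedforward variant concatenates and feeds a combination MLP instead).

```python
import torch

# Edge embeddings in NEF layout: (n_atoms, max_num_neighbors, d) = (2, 1, 3).
output_edge_embeddings = torch.tensor([[[1.0, 2.0, 3.0]],
                                       [[10.0, 20.0, 30.0]]])

# Flattened position of the reversed edge: edge (0 -> 1) reverses to (1 -> 0)
# and vice versa, so the two flattened slots swap.
reverse_neighbor_index = torch.tensor([1, 0])

n_atoms, max_neighbors, d = output_edge_embeddings.shape
reversed_messages = output_edge_embeddings.reshape(n_atoms * max_neighbors, d)[
    reverse_neighbor_index
].reshape(n_atoms, max_neighbors, d)

# The message passed to the next layer mixes the forward and reversed directions.
mixed = 0.5 * (output_edge_embeddings + reversed_messages)
print(mixed[0, 0])  # tensor([ 5.5000, 11.0000, 16.5000])
```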
Averages forward and reversed edge messages between layers.
+
+        :param inputs: Dictionary containing input tensors required for feature
+            computation
+        :param use_manual_attention: Whether to use manual attention computation
+            (required for double backward when edge vectors require gradients)
+        :return: Tuple of two lists:
+            - List of node feature tensors from all GNN layers
+            - List of edge feature tensors from all GNN layers
+        """
+        node_features_list: List[torch.Tensor] = []
+        edge_features_list: List[torch.Tensor] = []
+        input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"])
+        for node_embedder, gnn_layer in zip(
+            self.node_embedders, self.gnn_layers, strict=True
+        ):
+            input_node_embeddings = node_embedder(inputs["element_indices_nodes"])
+            output_node_embeddings, output_edge_embeddings = gnn_layer(
+                input_node_embeddings,
+                input_edge_embeddings,
+                inputs["element_indices_neighbors"],
+                inputs["edge_vectors"],
+                inputs["padding_mask"],
+                inputs["edge_distances"],
+                inputs["cutoff_factors"],
+                use_manual_attention,
+            )
+            node_features_list.append(output_node_embeddings)
+            edge_features_list.append(output_edge_embeddings)
+
+            # The GNN contraction happens by reordering the messages,
+            # using a reversed neighbor list, so the new input message
+            # from atom `j` to atom `i` on the GNN layer N+1 is a
+            # reversed message from atom `i` to atom `j` on the GNN layer N.
+            # (Flatten, index, and reshape to the original shape)
+            new_input_messages = output_edge_embeddings.reshape(
+                output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1],
+                output_edge_embeddings.shape[2],
+            )[inputs["reverse_neighbor_index"]].reshape(
+                output_edge_embeddings.shape[0],
+                output_edge_embeddings.shape[1],
+                output_edge_embeddings.shape[2],
+            )
+            input_edge_embeddings = 0.5 * (input_edge_embeddings + new_input_messages)
+        return node_features_list, edge_features_list
+
+    def _calculate_long_range_features(
+        self,
+        systems: List[System],
+        node_features_list: List[torch.Tensor],
+        edge_distances: torch.Tensor,
+        padding_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Calculate long-range electrostatic features using Ewald summation.
+        Forces use_ewald=True during training for stability.
+
+        :param systems: List of `metatomic.torch.System` objects to process.
+        :param node_features_list: List of node feature tensors from each GNN layer.
+        :param edge_distances: Tensor of edge distances [n_atoms, max_num_neighbors].
+        :param padding_mask: Boolean mask indicating real vs padded neighbors
+            [n_atoms, max_num_neighbors].
+        :return: Tensor of long-range features [n_atoms, d_pet].
+        """
+        if self.training:
+            # Currently, the long-range implementation shows instabilities
+            # during training if P3MCalculator is used instead of the
+            # EwaldCalculator. We will use the EwaldCalculator for training.
+ self.long_range_featurizer.use_ewald = True + flattened_lengths = edge_distances[padding_mask] + short_range_features = ( + torch.stack(node_features_list).sum(dim=0) + * (1 / len(node_features_list)) ** 0.5 + ) + long_range_features = self.long_range_featurizer( + systems, short_range_features, flattened_lengths + ) + return long_range_features + + def _get_output_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Concatenate node and edge features from all layers into intermediate + feature representations. Edge features are summed with cutoff weighting. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping "features" to a TensorMap of intermediate + representations, either per-atom or summed over atoms. + """ + features_dict: Dict[str, TensorMap] = {} + node_features = torch.cat(node_features_list, dim=1) + edge_features = torch.cat(edge_features_list, dim=2) + edge_features = (edge_features * cutoff_factors[:, :, None]).sum(dim=1) + features = torch.cat([node_features, edge_features], dim=1) + + feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=features, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + features.shape[-1], device=features.device + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + feature_tmap = mts.slice( + feature_tmap, + axis="samples", + selection=selected_atoms, + ) + if requested_outputs["features"].per_atom: + features_dict["features"] = feature_tmap + else: + features_dict["features"] = sum_over_atoms(feature_tmap) + return features_dict + + def _calculate_last_layer_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + ) -> Tuple[Dict[str, List[torch.Tensor]], Dict[str, List[torch.Tensor]]]: + """ + Apply output-specific heads to node and edge features from each GNN layer. + Returns dictionaries mapping output names to lists of head-transformed features. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. 
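The cutoff-weighted edge reduction `(edge_features * cutoff_factors[:, :, None]).sum(dim=1)` recurs several times in this file. A minimal shape check with made-up numbers (the factor 0 doubles as the padding value for missing neighbors):

```python
import torch

# Illustrative shapes: 2 atoms, up to 3 neighbors each, feature size 4.
edge_features = torch.randn(2, 3, 4)
cutoff_factors = torch.tensor([[1.0, 0.5, 0.0],   # third neighbor near/beyond the cutoff
                               [1.0, 0.0, 0.0]])  # padded slots carry factor 0

# Weight each edge by its cutoff factor, then sum over the neighbor axis.
per_atom_edge_contribution = (edge_features * cutoff_factors[:, :, None]).sum(dim=1)
assert per_atom_edge_contribution.shape == (2, 4)
```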
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of node last layer features + - Dictionary mapping output names to lists of edge last layer features + """ + node_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + + # Calculating node last layer features + for output_name, node_heads in self.node_heads.items(): + if output_name not in node_last_layer_features_dict: + node_last_layer_features_dict[output_name] = [] + for i, node_head in enumerate(node_heads): + node_last_layer_features_dict[output_name].append( + node_head(node_features_list[i]) + ) + + # Calculating edge last layer features + for output_name, edge_heads in self.edge_heads.items(): + if output_name not in edge_last_layer_features_dict: + edge_last_layer_features_dict[output_name] = [] + for i, edge_head in enumerate(edge_heads): + edge_last_layer_features_dict[output_name].append( + edge_head(edge_features_list[i]) + ) + + return node_last_layer_features_dict, edge_last_layer_features_dict + + def _get_output_last_layer_features( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge last layer features for requested last layer + features output. Edge features are summed with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping requested last layer features output names + to TensorMaps of last layer features, either per-atom or summed over atoms. 
+ """ + last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + last_layer_features_outputs: Dict[str, TensorMap] = {} + for output_name in node_last_layer_features_dict.keys(): + if not should_compute_last_layer_features(output_name, requested_outputs): + continue + if output_name not in last_layer_features_dict: + last_layer_features_dict[output_name] = [] + for i in range(len(node_last_layer_features_dict[output_name])): + node_last_layer_features = node_last_layer_features_dict[output_name][i] + edge_last_layer_features = edge_last_layer_features_dict[output_name][i] + edge_last_layer_features = ( + edge_last_layer_features * cutoff_factors[:, :, None] + ).sum(dim=1) + last_layer_features_dict[output_name].append(node_last_layer_features) + last_layer_features_dict[output_name].append(edge_last_layer_features) + + for output_name in requested_outputs: + if not ( + output_name.startswith("mtt::aux::") + and output_name.endswith("_last_layer_features") + ): + continue + base_name = output_name.replace("mtt::aux::", "").replace( + "_last_layer_features", "" + ) + # the corresponding output could be base_name or mtt::base_name + if f"mtt::{base_name}" in last_layer_features_dict: + base_name = f"mtt::{base_name}" + last_layer_features_values = torch.cat( + last_layer_features_dict[base_name], dim=1 + ) + last_layer_feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=last_layer_features_values, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + last_layer_features_values.shape[-1], + device=last_layer_features_values.device, + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + last_layer_feature_tmap = mts.slice( + last_layer_feature_tmap, + axis="samples", + selection=selected_atoms, + ) + last_layer_features_options = requested_outputs[output_name] + if last_layer_features_options.per_atom: + last_layer_features_outputs[output_name] = last_layer_feature_tmap + else: + last_layer_features_outputs[output_name] = sum_over_atoms( + last_layer_feature_tmap + ) + return last_layer_features_outputs + + def _calculate_atomic_predictions( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + padding_mask: torch.Tensor, + cutoff_factors: torch.Tensor, + outputs: Dict[str, ModelOutput], + ) -> Tuple[ + Dict[str, List[List[torch.Tensor]]], Dict[str, List[List[torch.Tensor]]] + ]: + """ + Apply final linear layers to last layer features to produce + per-atom predictions. Handles multiple blocks per output and sums + edge contributions with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param outputs: Dictionary of requested outputs. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of lists of node atomic + prediction tensors (one list per GNN layer, one tensor per block) + - Dictionary mapping output names to lists of lists of edge atomic + prediction tensors (one list per GNN layer, one tensor per block) + """ + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + + # Computing node atomic predictions. Since we have last layer features + # for each GNN layer, and each last layer can have multiple blocks, + # we apply each last layer block to each of the last layer features. + + for output_name, node_last_layers in self.node_last_layers.items(): + if output_name in outputs: + node_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, node_last_layer in enumerate(node_last_layers): + node_last_layer_features = node_last_layer_features_dict[ + output_name + ][i] + node_atomic_predictions_by_block: List[torch.Tensor] = [] + for node_last_layer_by_block in node_last_layer.values(): + node_atomic_predictions_by_block.append( + node_last_layer_by_block(node_last_layer_features) + ) + node_atomic_predictions_dict[output_name].append( + node_atomic_predictions_by_block + ) + + # Computing edge atomic predictions. Following the same logic as above, + # we (1) iterate over the last layer features and last layer blocks, and (2) + # sum the edge features with cutoff factors to get their per-node contribution. + + for output_name, edge_last_layers in self.edge_last_layers.items(): + if output_name in outputs: + edge_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, edge_last_layer in enumerate(edge_last_layers): + edge_last_layer_features = edge_last_layer_features_dict[ + output_name + ][i] + edge_atomic_predictions_by_block: List[torch.Tensor] = [] + for edge_last_layer_by_block in edge_last_layer.values(): + edge_atomic_predictions = edge_last_layer_by_block( + edge_last_layer_features + ) + expanded_padding_mask = padding_mask[..., None].repeat( + 1, 1, edge_atomic_predictions.shape[2] + ) + edge_atomic_predictions = torch.where( + ~expanded_padding_mask, 0.0, edge_atomic_predictions + ) + edge_atomic_predictions_by_block.append( + (edge_atomic_predictions * cutoff_factors[:, :, None]).sum( + dim=1 + ) + ) + edge_atomic_predictions_dict[output_name].append( + edge_atomic_predictions_by_block + ) + + return node_atomic_predictions_dict, edge_atomic_predictions_dict + + def _get_output_atomic_predictions( + self, + systems: List[System], + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_vectors: torch.Tensor, + system_indices: torch.Tensor, + sample_labels: Labels, + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge atomic predictions into final TensorMaps. + Handles rank-2 Cartesian tensors by symmetrizing them. + Returns per-atom or per-structure predictions based on output configuration. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of node atomic prediction tensors (one list per GNN layer, + one tensor per block). 
+ :param edge_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of edge atomic prediction tensors (one list per GNN layer, + one tensor per block). + :param edge_vectors: Tensor of edge vectors [n_atoms, max_num_neighbors, 3]. + :param system_indices: Tensor mapping each atom to its system index + [n_atoms]. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param outputs: Dictionary of requested outputs. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :return: Dictionary mapping requested output names to TensorMaps of + predictions, either per-atom or summed over atoms. + """ + atomic_predictions_tmap_dict: Dict[str, TensorMap] = {} + for output_name in self.target_names: + if output_name in outputs: + atomic_predictions_by_block = { + key: torch.zeros( + 1, dtype=edge_vectors.dtype, device=edge_vectors.device + ) + for key in self.output_shapes[output_name].keys() + } + + node_atomic_predictions_by_block = node_atomic_predictions_dict[ + output_name + ] + edge_atomic_predictions_by_block = edge_atomic_predictions_dict[ + output_name + ] + for i in range(len(node_atomic_predictions_by_block)): + node_atomic_prediction_block = node_atomic_predictions_by_block[i] + edge_atomic_prediction_block = edge_atomic_predictions_by_block[i] + for j, key in enumerate(atomic_predictions_by_block): + node_atomic_predictions = node_atomic_prediction_block[j] + edge_atomic_predictions = edge_atomic_prediction_block[j] + atomic_predictions_by_block[key] = atomic_predictions_by_block[ + key + ] + (node_atomic_predictions + edge_atomic_predictions) + + if output_name == "non_conservative_stress": # TODO: variants + block_key = list(atomic_predictions_by_block.keys())[0] + output_shapes_values = list( + self.output_shapes[output_name].values() + ) + num_properties = output_shapes_values[0][-1] + symmetrized = process_non_conservative_stress( + atomic_predictions_by_block[block_key], + systems, + system_indices, + num_properties, + ) + atomic_predictions_by_block[block_key] = symmetrized + + blocks = [ + TensorBlock( + values=atomic_predictions_by_block[key].reshape([-1] + shape), + samples=sample_labels, + components=components, + properties=properties, + ) + for key, shape, components, properties in zip( + self.output_shapes[output_name].keys(), + self.output_shapes[output_name].values(), + self.component_labels[output_name], + self.property_labels[output_name], + strict=True, + ) + ] + atomic_predictions_tmap_dict[output_name] = TensorMap( + keys=self.key_labels[output_name], + blocks=blocks, + ) + # If selected atoms request is provided, we slice the atomic predictions + # tensor maps to get the predictions for the selected atoms only. + + if selected_atoms is not None: + for output_name, tmap in atomic_predictions_tmap_dict.items(): + atomic_predictions_tmap_dict[output_name] = mts.slice( + tmap, axis="samples", selection=selected_atoms + ) + + # If per-atom predictions are requested, we return the atomic predictions + # tensor maps. Otherwise, we sum the atomic predictions over the atoms + # to get the final per-structure predictions for each requested output. 
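`sum_over_atoms` is a metatrain utility; the minimal sketch below reproduces its intent with plain `index_add_`, summing per-atom predictions into per-structure predictions by system index. Shapes and values are hypothetical.

```python
import torch

# 5 atoms belonging to 2 structures, 1 scalar target per atom.
per_atom = torch.randn(5, 1)
system_index = torch.tensor([0, 0, 0, 1, 1])  # which structure each atom belongs to

per_structure = torch.zeros(2, 1)
per_structure.index_add_(0, system_index, per_atom)  # sum atoms within each structure
assert torch.allclose(per_structure[0], per_atom[:3].sum(dim=0))
```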
+ + for output_name, atomic_property in atomic_predictions_tmap_dict.items(): + if outputs[output_name].per_atom: + atomic_predictions_tmap_dict[output_name] = atomic_property + else: + atomic_predictions_tmap_dict[output_name] = sum_over_atoms( + atomic_property + ) + + return atomic_predictions_tmap_dict + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + context: Literal["restart", "finetune", "export"], + ) -> "PET": + if context == "restart": + logging.info(f"Using latest model from epoch {checkpoint['epoch']}") + model_state_dict = checkpoint["model_state_dict"] + elif context in {"finetune", "export"}: + logging.info(f"Using best model from epoch {checkpoint['best_epoch']}") + model_state_dict = checkpoint["best_model_state_dict"] + else: + raise ValueError("Unknown context tag for checkpoint loading!") + + # Create the model + model_data = checkpoint["model_data"] + model = cls( + hypers=model_data["model_hypers"], + dataset_info=model_data["dataset_info"], + ) + + finetune_config = model_state_dict.pop("finetune_config", {}) + if finetune_config: + # Apply the finetuning strategy + model = apply_finetuning_strategy(model, finetune_config) + state_dict_iter = iter(model_state_dict.values()) + next(state_dict_iter) # skip the species_to_species_index + dtype = next(state_dict_iter).dtype + model.to(dtype).load_state_dict(model_state_dict) + model.additive_models[0].sync_tensor_maps() + model.scaler.sync_tensor_maps() + + # Loading the metadata from the checkpoint + model.metadata = merge_metadata(model.metadata, checkpoint.get("metadata")) + + return model + + def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: + dtype = next(self.parameters()).dtype + if dtype not in self.__supported_dtypes__: + raise ValueError(f"unsupported dtype {dtype} for PET") + + # Make sure the model is all in the same dtype + # For example, after training, the additive models could still be in + # float64 + self.to(dtype) + + # Additionally, the composition model contains some `TensorMap`s that cannot + # be registered correctly with Pytorch. This function moves them: + self.additive_models[0].weights_to(torch.device("cpu"), torch.float64) + + interaction_ranges = [self.num_gnn_layers * self.cutoff] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + + capabilities = ModelCapabilities( + outputs=self.outputs, + atomic_types=self.atomic_types, + interaction_range=interaction_range, + length_unit=self.dataset_info.length_unit, + supported_devices=self.__supported_devices__, + dtype=dtype_to_str(dtype), + ) + + metadata = merge_metadata(self.metadata, metadata) + + return AtomisticModel(self.eval(), metadata, capabilities) + + def _add_output(self, target_name: str, target_info: TargetInfo) -> None: + """ + Register a new output target by creating corresponding heads and last layers. + Sets up node/edge heads and linear layers for all readout layers. + + :param target_name: Name of the target to add. + :param target_info: TargetInfo object containing details about the target. + """ + # one output shape for each tensor block, grouped by target (i.e. 
tensormap) + self.output_shapes[target_name] = {} + for key, block in target_info.layout.items(): + dict_key = target_name + for n, k in zip(key.names, key.values, strict=True): + dict_key += f"_{n}_{int(k)}" + self.output_shapes[target_name][dict_key] = [ + len(comp.values) for comp in block.components + ] + [len(block.properties.values)] + + self.outputs[target_name] = ModelOutput( + quantity=target_info.quantity, + unit=target_info.unit, + per_atom=True, + description=target_info.description, + ) + + self.node_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_node, self.d_head, bias=False), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head, bias=False), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_pet, self.d_head, bias=False), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head, bias=False), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.node_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: torch.nn.Linear( + self.d_head, + prod(shape), + bias=False, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: torch.nn.Linear( + self.d_head, + prod(shape), + bias=False, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + # Register last-layer parameters, in the same order as they are returned as + # last-layer features in the model + self.last_layer_parameter_names[target_name] = [] + for layer_index in range(self.num_readout_layers): + for key in self.output_shapes[target_name].keys(): + self.last_layer_parameter_names[target_name].append( + f"node_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + self.last_layer_parameter_names[target_name].append( + f"edge_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + + ll_features_name = get_last_layer_features_name(target_name) + self.outputs[ll_features_name] = ModelOutput( + per_atom=True, description=f"last layer features for {target_name}" + ) + self.key_labels[target_name] = target_info.layout.keys + self.component_labels[target_name] = [ + block.components for block in target_info.layout.blocks() + ] + self.property_labels[target_name] = [ + block.properties for block in target_info.layout.blocks() + ] + + def _move_labels_to_device(self, device: torch.device) -> None: + self.single_label = self.single_label.to(device) + self.key_labels = { + output_name: label.to(device) + for output_name, label in self.key_labels.items() + } + self.component_labels = { + output_name: [ + [labels.to(device) for labels in components_block] + for components_block in components_tmap + ] + for output_name, components_tmap in self.component_labels.items() + } + self.property_labels = { + output_name: [labels.to(device) for labels in properties_tmap] + for output_name, properties_tmap in self.property_labels.items() + } + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["model_ckpt_version"] == v: + update = getattr(checkpoints, f"model_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["model_ckpt_version"] = v + 1 + + if 
checkpoint["model_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using model " + f"version {checkpoint['model_ckpt_version']}, while the current model " + f"version is {cls.__checkpoint_version__}." + ) + + return checkpoint + + def get_checkpoint(self) -> Dict: + model_state_dict = self.state_dict() + model_state_dict["finetune_config"] = self.finetune_config + checkpoint = { + "architecture_name": "pet_10", + "model_ckpt_version": self.__checkpoint_version__, + "metadata": self.metadata, + "model_data": { + "model_hypers": self.hypers, + "dataset_info": self.dataset_info, + }, + "epoch": None, + "best_epoch": None, + "model_state_dict": model_state_dict, + "best_model_state_dict": self.state_dict(), + } + return checkpoint + + +def process_non_conservative_stress( + tensor: torch.Tensor, + systems: List[System], + system_indices: torch.Tensor, + num_properties: int, +) -> torch.Tensor: + """ + Symmetrizes and normalizes by the volume rank-2 Cartesian tensors that are meant + to predict the non-conservative stress. + + :param tensor: Tensor of shape [n_atoms, 9 * num_properties]. + :param systems: List of `metatomic.torch.System` objects to process. + :param system_indices: Tensor mapping each atom to its system index [n_atoms]. + :param num_properties: Number of properties in the tensor (e.g., 6 for stress). + :return: Symmetrized tensor of shape [n_atoms, 3, 3, num_properties], divided by the + cell volume. + """ + # Reshape to 3x3 matrix per atom + tensor_as_three_by_three = tensor.reshape(-1, 3, 3, num_properties) + + # Normalize by cell volume + volumes = torch.stack([torch.abs(torch.det(system.cell)) for system in systems]) + # Zero volume can happen due to metatomic's convention of zero cell + # vectors for non-periodic directions. The actual volume is +inf + volumes[volumes == 0.0] = torch.inf + volumes_by_atom = volumes[system_indices].unsqueeze(1).unsqueeze(2).unsqueeze(3) + tensor_as_three_by_three = tensor_as_three_by_three / volumes_by_atom + + # Symmetrize + tensor_as_three_by_three = ( + tensor_as_three_by_three + tensor_as_three_by_three.transpose(1, 2) + ) / 2.0 + + return tensor_as_three_by_three + + +def get_last_layer_features_name(target_name: str) -> str: + """ + Get the auxiliary output name for last layer features of a target. + + :param target_name: Name of the target. + :return: Name of the corresponding last layer features output. + """ + base_name = target_name.replace("mtt::", "") + return f"mtt::aux::{base_name}_last_layer_features" + + +def should_compute_last_layer_features( + output_name: str, requested_outputs: Dict[str, ModelOutput] +) -> bool: + """ + Check if last layer features should be computed for an output. + + :param output_name: Name of the output to check. + :param requested_outputs: Dictionary of requested outputs. + :return: True if last layer features should be computed, False otherwise. 
+ """ + if output_name in requested_outputs: + return True + ll_features_name = get_last_layer_features_name( + output_name.replace("mtt::aux::", "") + ) + return ll_features_name in requested_outputs diff --git a/src/metatrain/pet_10/modules/adaptive_cutoff.py b/src/metatrain/pet_10/modules/adaptive_cutoff.py new file mode 100644 index 0000000000..5f6206bdb1 --- /dev/null +++ b/src/metatrain/pet_10/modules/adaptive_cutoff.py @@ -0,0 +1,166 @@ +from typing import Optional + +import torch + +from .utilities import cutoff_func_bump as cutoff_func + + +# minimum value for the probe cutoff. this avoids getting too close +# to the central atom. in practice it could be also set to a larger value +DEFAULT_MIN_PROBE_CUTOFF = 0.5 +# recommended smooth cutoff width for effective neighbor number calculation +# smaller values lead to a more "step-like" behavior, but can be +# numerically unstable. in practice this will be called with the +# same cutoff as the main cutoff function +DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH = 1.0 + + +def get_adaptive_cutoffs( + centers: torch.Tensor, + edge_distances: torch.Tensor, + num_neighbors_adaptive: float, + num_nodes: int, + max_cutoff: float, + min_cutoff: float = DEFAULT_MIN_PROBE_CUTOFF, + cutoff_width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, + probe_spacing: Optional[float] = None, + weight_width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the adaptive cutoff values for each center atom. + + :param centers: Indices of the center atoms. + :param edge_distances: Distances between centers and their neighbors. + :param num_neighbors_adaptive: Target number of neighbors per atom. + :param num_nodes: Total number of center atoms. + :param max_cutoff: Maximum cutoff distance to consider. + :param min_cutoff: Minimum cutoff distance to consider. + :param cutoff_width: Width of the smooth cutoff taper region. + :param probe_spacing: Spacing between probe cutoffs. If None, it will be + automatically determined from the cutoff width. + :param weight_width: Width of the cutoff selection weight function. If None, it + will be automatically determined from the empirical neighbor counts. + :return: Adapted cutoff distances for each center atom. + """ + + # heuristic for the grid spacing of probe cutoffs, based on a + # the smoothness of the cutoff function + if probe_spacing is None: + probe_spacing = cutoff_width / 4.0 + probe_cutoffs = torch.arange( + min_cutoff, + max_cutoff, + probe_spacing, + device=edge_distances.device, + dtype=edge_distances.dtype, + ) + with torch.profiler.record_function("PET::get_effective_num_neighbors"): + effective_num_neighbors = get_effective_num_neighbors( + edge_distances, + probe_cutoffs, + centers, + num_nodes, + width=cutoff_width, + ) + + with torch.profiler.record_function("PET::get_cutoff_weights"): + cutoffs_weights = get_gaussian_cutoff_weights( + effective_num_neighbors, + num_neighbors_adaptive, + width=weight_width, + ) + with torch.profiler.record_function("PET::calculate_adapted_cutoffs"): + adapted_atomic_cutoffs = probe_cutoffs @ cutoffs_weights.T + return adapted_atomic_cutoffs + + +def get_effective_num_neighbors( + edge_distances: torch.Tensor, + probe_cutoffs: torch.Tensor, + centers: torch.Tensor, + num_nodes: int, + width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, +) -> torch.Tensor: + """ + Computes the effective number of neighbors for each probe cutoff. + + :param edge_distances: Distances between centers and their neighbors. + :param probe_cutoffs: Probe cutoff distances. 
+ :param centers: Indices of the center atoms. + :param num_nodes: Total number of center atoms. + :param width: Width of the cutoff function. + :return: Effective number of neighbors for each center atom and probe cutoff. + """ + + weights = cutoff_func( + edge_distances.unsqueeze(0), probe_cutoffs.unsqueeze(1), width + ) + + probe_num_neighbors = torch.zeros( + (len(probe_cutoffs), num_nodes), + dtype=edge_distances.dtype, + device=edge_distances.device, + ) + # accumulate the weights for all probe cutoffs and center atoms at once + probe_num_neighbors.index_add_(1, centers, weights) + probe_num_neighbors = probe_num_neighbors.T + + return probe_num_neighbors + + +def get_gaussian_cutoff_weights( + effective_num_neighbors: torch.Tensor, + num_neighbors_adaptive: float, + width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the weights for each probe cutoff based on + the effective number of neighbors using Gaussian weights + centered at the expected number of neighbors. + + :param effective_num_neighbors: Effective number of neighbors for each center atom + and probe cutoff. + :param num_neighbors_adaptive: Target maximum number of neighbors per atom. + :param width: Width of the Gaussian cutoff selection function. + :return: Weights for each probe cutoff. + """ + diff = effective_num_neighbors - num_neighbors_adaptive + + # adds a "baseline" corresponding to uniformly-distributed atoms + # this has multiple "good" effects: it pushes the cutoff "out" when + # there are few neighbors, and "in" when there are many, and it + # stabilizes the weights with respect to variations in the neighbor + # distribution when there are empty ranges leading to "flat" + # neighbor count distribution + x = torch.linspace( + 0, + 1, + effective_num_neighbors.shape[1], + device=effective_num_neighbors.device, + dtype=effective_num_neighbors.dtype, + ) + baseline = num_neighbors_adaptive * x**3 + + diff = diff + baseline.unsqueeze(0) + if width is None: + # adaptive width from neighbor-count slope along probe axis (last dim) + eps = 1e-12 + if diff.shape[-1] == 1: + # Can't compute gradient from single point; use scaled diff as proxy + width_t = diff.abs() * 0.5 + eps + else: + # Compute numerical gradient: centered differences for interior, + # one-sided differences at boundaries + (width_t,) = torch.gradient(diff, dim=-1) + width_t = width_t.abs().clamp_min(eps) + else: + width_t = torch.ones_like(diff) * width + + logw = -0.5 * (diff / width_t) ** 2 + weights = torch.exp(logw - logw.max()) + + # row-wise normalization of the weights + weights_sum = weights.sum(dim=1, keepdim=True) + weights = weights / weights_sum + + return weights diff --git a/src/metatrain/pet_10/modules/finetuning.py b/src/metatrain/pet_10/modules/finetuning.py new file mode 100644 index 0000000000..27f2086f19 --- /dev/null +++ b/src/metatrain/pet_10/modules/finetuning.py @@ -0,0 +1,264 @@ +# mypy: disable-error-code=misc +# We ignore misc errors in this file because TypedDict +# with default values is not allowed by mypy. 
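Before moving on to the finetuning module below, here is a compact numeric illustration of the probe-grid idea from `adaptive_cutoff.py` above: count neighbors smoothly for a grid of probe cutoffs, weight the probes by how close their neighbor count is to the target, and take the weighted average as the adapted cutoff. A hard step stands in for the bump cutoff function (which lives in the `utilities` module not shown here), and all numbers are made up.

```python
import torch

edge_distances = torch.tensor([1.0, 2.0, 3.5, 4.5])  # all edges of a single atom
centers = torch.zeros(4, dtype=torch.long)            # they all belong to atom 0
probe_cutoffs = torch.arange(1.5, 5.0, 0.5)           # grid of candidate cutoffs
target_neighbors = 2.0

# Effective neighbor count for each probe cutoff (shape: n_probes x n_atoms);
# a hard step replaces the smooth bump function for simplicity.
weights = (edge_distances.unsqueeze(0) < probe_cutoffs.unsqueeze(1)).to(torch.float64)
counts = torch.zeros(len(probe_cutoffs), 1, dtype=torch.float64)
counts.index_add_(1, centers, weights)

# Gaussian weights over the probes, centred on the target neighbor count,
# then a weighted average of the probe cutoffs gives the adapted cutoff.
gauss = torch.exp(-0.5 * (counts.T - target_neighbors) ** 2)
gauss = gauss / gauss.sum(dim=1, keepdim=True)
adapted_cutoff = gauss @ probe_cutoffs.to(torch.float64)
print(adapted_cutoff)  # close to the probes that see about 2 neighbors (~3.0 here)
```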
+from typing import Any, Optional, Tuple + +import torch +import torch.nn as nn +from typing_extensions import Literal, NotRequired, TypedDict + + +class LoRaFinetuneConfig(TypedDict): + """Configuration for LoRA finetuning strategy.""" + + rank: int + """Rank of the LoRA matrices.""" + alpha: float + """Scaling factor for the LoRA matrices.""" + target_modules: NotRequired[list[str]] + + +class HeadsFinetuneConfig(TypedDict): + """Configuration for heads finetuning strategy.""" + + head_modules: list[str] + """List of module name prefixes for the prediction heads to finetune.""" + last_layer_modules: list[str] + """List of module name prefixes for the last layers to finetune.""" + + +class NoFinetuneHypers(TypedDict): + """Hypers that indicate that no finetuning is to be applied.""" + + read_from: None = None + """No finetuning is indicated by setting this argument to None. + + The rest of finetuning hyperparameters are then ignored. + """ + method: NotRequired[Any] + config: NotRequired[Any] + inherit_heads: NotRequired[Any] + + +class FullFinetuneHypers(TypedDict): + """Hyperparameters to use full finetuning of PET models. + + This means all model parameters are trainable. + """ + + method: Literal["full"] = "full" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: NotRequired[Any] + """No configuration needed for full finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class LoRaFinetuneHypers(TypedDict): + """Hyperparameters for LoRA finetuning of PET models. + + Injects LoRA layers and finetunes only them. + """ + + method: Literal["lora"] = "lora" + """Finetuning method to use""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: LoRaFinetuneConfig + """Configuration for LoRA finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class HeadsFinetuneHypers(TypedDict): + """Hyperparameters for heads finetuning of PET models. + + Freezes all model parameters except for the prediction heads + and last layers. + """ + + method: Literal["heads"] = "heads" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: HeadsFinetuneConfig + """Configuration for heads finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +FinetuneHypers = FullFinetuneHypers | LoRaFinetuneHypers | HeadsFinetuneHypers + + +def apply_finetuning_strategy(model: nn.Module, strategy: FinetuneHypers) -> nn.Module: + """ + Apply the specified finetuning strategy to the model. + This function modifies the model in place based on the provided strategy. + + :param model: The model to be finetuned. + :param strategy: A dictionary specifying the finetuning strategy. + The strategy method can be one of the following: + - lora: Inject LoRA layers into the model, or reapply training if already + present. 
+ - heads: Freeze all parameters except for the heads and last layers. + - full: All parameters are trainable. + Additionally, the strategy can include an "inherit_heads" key, + which is a dictionary mapping the new trainable targets to the existing + targets in the model. This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization. + :return: The modified model with the finetuning strategy applied. + """ + + for param in model.parameters(): + param.requires_grad = True + + if strategy["method"] == "full": + # Full finetuning, all parameters are trainable + pass + + elif strategy["method"] == "lora": + lora_config = strategy["config"] + lora_already_applied = any(isinstance(m, LoRALinear) for m in model.modules()) + if not lora_already_applied: + model_device = next(model.parameters()).device + model_dtype = next(model.parameters()).dtype + model = inject_lora_layers( + model, + target_modules=tuple( + lora_config.get("target_modules", ["input_linear", "output_linear"]) + ), + rank=lora_config.get("rank", 4), + alpha=lora_config.get("alpha", 8), + device=model_device, + dtype=model_dtype, + ) + + # Freeze all except LoRA + for name, param in model.named_parameters(): + if "lora_" not in name: + param.requires_grad = False + + elif strategy["method"] == "heads": + heads_config = strategy.get( + "config", + { + "head_modules": ["node_heads", "edge_heads"], + "last_layer_modules": ["node_last_layers", "edge_last_layers"], + }, + ) + + head_keywords = heads_config.get("head_modules", []) + last_layer_keywords = heads_config.get("last_layer_modules", []) + + for name, param in model.named_parameters(): + if any(name.startswith(kw) for kw in head_keywords + last_layer_keywords): + param.requires_grad = True + else: + param.requires_grad = False + + else: + raise ValueError( + f"Unknown finetuning strategy: {strategy['method']}. Available methods " + "are: 'full', 'lora', 'heads'." + ) + + model.finetune_config = strategy + + inherit_heads_config = strategy["inherit_heads"] + if inherit_heads_config: + for dest_target_name, source_target_name in inherit_heads_config.items(): + model_parameters = dict(model.named_parameters()) + if not any(f".{source_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the source target name '{source_target_name}' was not found in " + "the model. Please specify the correct source target name." + ) + if not any(f".{dest_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the destination target name '{dest_target_name}' was not found " + "in the model. Please specify the correct destination target name." + ) + for name, param in model_parameters.items(): + if f".{source_target_name}." in name: + corresponding_dest_name = name.replace( + source_target_name, dest_target_name + ) + if corresponding_dest_name in model_parameters: + model_parameters[corresponding_dest_name].data.copy_(param.data) + else: + raise ValueError( + f"Destination head '{dest_target_name}' not found in model." + ) + return model + + +def inject_lora_layers( + model: nn.Module, + target_modules: Tuple[str, ...] = ("input_linear", "output_linear"), + rank: int = 4, + alpha: float = 1.0, + device: Optional[str] = None, + dtype: Optional[torch.dtype] = None, +) -> nn.Module: + """ + Inject LoRA layers into the model. 
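The strategy dictionaries consumed by `apply_finetuning_strategy` mirror the TypedDicts above. A hypothetical LoRA configuration is sketched below; the checkpoint path is a placeholder, and the numeric values simply repeat the defaults that the code above falls back to.

```python
lora_strategy = {
    "method": "lora",
    "read_from": "path/to/pretrained.ckpt",  # placeholder path
    "config": {
        "rank": 4,    # default used by apply_finetuning_strategy
        "alpha": 8,   # default used by apply_finetuning_strategy
        "target_modules": ["input_linear", "output_linear"],
    },
    "inherit_heads": {},  # no head weights are copied from existing targets
}
# model = apply_finetuning_strategy(model, lora_strategy)
```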
+ This function replaces the specified linear layers in the model with + LoRALinear layers. + + :param model: The model to modify. + :param target_modules: A tuple of strings specifying the names of the attributes of + the modules to be replaced with LoRA layers. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + :param device: The device to which the LoRA layers should be moved. If None, the + LoRA layers will be on the same device as the original model. + :param dtype: The data type to which the LoRA layers should be converted. If + None, the LoRA layers will have the same dtype as the original model. + :return: The modified model with LoRA layers injected. + """ + for _, module in model.named_modules(): + for attr in target_modules: + if hasattr(module, attr): + linear = getattr(module, attr) + if isinstance(linear, nn.Linear): + lora_linear = LoRALinear(linear, rank=rank, alpha=alpha) + lora_linear = lora_linear.to(dtype=dtype, device=device) + setattr(module, attr, lora_linear) + return model + + +class LoRALinear(nn.Module): + """ + LoRA Linear layer. + This is a wrapper around nn.Linear that adds LoRA functionality. + LoRA is a technique for low-rank adaptation of large language models. + It allows for efficient fine-tuning of large models by injecting low-rank + matrices into the model's weights. + + :param linear_layer: The original linear layer to be wrapped. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + """ + + def __init__(self, linear_layer: nn.Module, rank: int = 4, alpha: float = 1.0): + super().__init__() + self.linear = linear_layer + self.lora_A = nn.Linear(linear_layer.in_features, rank, bias=False) + self.lora_B = nn.Linear(rank, linear_layer.out_features, bias=False) + self.scaling = alpha / rank + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + self.scaling * self.lora_B(self.lora_A(x)) diff --git a/src/metatrain/pet_10/modules/nef.py b/src/metatrain/pet_10/modules/nef.py new file mode 100644 index 0000000000..20ed6dc8fc --- /dev/null +++ b/src/metatrain/pet_10/modules/nef.py @@ -0,0 +1,248 @@ +""" +Module with functions to manipulate NEF (Node Edge Feature) arrays. + +The NEF representation is what the internals of PET use. +In the NEF representation, the first dimension is the center node +(i.e. the "i" node in an "i -> j" edge), and the second dimension +is the edges for that node. Not all center nodes have the same number +of edges, so padding is used to ensure that all nodes have the same +number of edges. + +Most of the functions have the purpose of converting between +edge arrays with shape (n_edges, ...) and NEF arrays with shape +(n_nodes, n_edges_per_node, ...). +""" + +from typing import List, Optional, Tuple + +import torch + + +def get_nef_indices( + centers: torch.Tensor, n_nodes: int, n_edges_per_node: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes tensors of indices useful to convert between edge + and NEF layouts; the usage and function of `nef_indices` and + `nef_to_edges_neighbor` is clear in the ``edge_array_to_nef`` + and ``nef_array_to_edges`` functions below. + + :param centers: A 1D tensor of shape (n_edges,) containing the + indices of the center nodes for each edge, with the center nodes + being the "i" node in an "i -> j" edge. + :param n_nodes: The number of nodes in the graph. + :param n_edges_per_node: The maximum number of edges per node. 
+ + :return: A tuple with three tensors (nef_indices, nef_to_edges_neighbor, nef_mask). + In particular: + nef_array = edge_array[nef_indices] + edge_array = nef_array[centers, nef_to_edges_neighbor] + The third output, nef_mask, is a mask that can be used to + filter out the padding values in the NEF array, as different + nodes will have, in general, different number of edges. + """ + + bincount = torch.bincount(centers, minlength=n_nodes) + + arange = torch.arange(n_edges_per_node, device=centers.device) + arange_expanded = arange.view(1, -1).expand(n_nodes, -1) + nef_mask = arange_expanded < bincount.view(-1, 1) + + argsort = torch.argsort(centers, stable=True) + + nef_indices = torch.zeros( + (n_nodes, n_edges_per_node), dtype=torch.long, device=centers.device + ) + nef_indices[nef_mask] = argsort + + nef_to_edges_neighbor = torch.empty_like(centers, dtype=torch.long) + nef_to_edges_neighbor[argsort] = arange_expanded[nef_mask] + + return nef_indices, nef_to_edges_neighbor, nef_mask + + +def get_corresponding_edges(array: torch.Tensor) -> torch.Tensor: + """ + Computes the corresponding edge (i.e., the edge that goes in the + opposite direction) for each edge in the array; this is useful + in the message-passing operation. + + :param array: A 2D tensor of shape (n_edges, 5). For each i -> j + edge, the first column contains the index of the center node i, + the second column contains the index of the neighbor node j, + and the last three columns contain the cell shifts along x, y, and z + directions, respectively. + + :return: A 1D tensor of shape (n_edges,) containing, for each edge, + the index of the corresponding edge (i.e., the edge that goes + in the opposite direction). If the input array is empty, an + empty tensor is returned. + """ + + if array.numel() == 0: + return torch.empty((0,), dtype=array.dtype, device=array.device) + + array = array.to(torch.int64) # avoid overflow + + centers = array[:, 0] + neighbors = array[:, 1] + cell_shifts_x = array[:, 2] + cell_shifts_y = array[:, 3] + cell_shifts_z = array[:, 4] + + # will be useful later + negative_cell_shifts_x = -cell_shifts_x + negative_cell_shifts_y = -cell_shifts_y + negative_cell_shifts_z = -cell_shifts_z + + # create a unique identifier for each edge + # first, we shift the cell_shifts so that the minimum value is 0 + min_cell_shift_x = cell_shifts_x.min() + cell_shifts_x = cell_shifts_x - min_cell_shift_x + negative_cell_shifts_x = negative_cell_shifts_x - min_cell_shift_x + + min_cell_shift_y = cell_shifts_y.min() + cell_shifts_y = cell_shifts_y - min_cell_shift_y + negative_cell_shifts_y = negative_cell_shifts_y - min_cell_shift_y + + min_cell_shift_z = cell_shifts_z.min() + cell_shifts_z = cell_shifts_z - min_cell_shift_z + negative_cell_shifts_z = negative_cell_shifts_z - min_cell_shift_z + + max_centers_neigbors = centers.max() + 1 # same as neighbors.max() + 1 + max_shift_x = cell_shifts_x.max() + 1 + max_shift_y = cell_shifts_y.max() + 1 + max_shift_z = cell_shifts_z.max() + 1 + + size_1 = max_shift_z + size_2 = max_shift_y * size_1 + size_3 = max_shift_x * size_2 + size_4 = max_centers_neigbors * size_3 + + unique_id = ( + centers * size_4 + + neighbors * size_3 + + cell_shifts_x * size_2 + + cell_shifts_y * size_1 + + cell_shifts_z + ) + + # the inverse is the same, but centers and neighbors are swapped + # and we use the negative values of the cell_shifts + unique_id_inverse = ( + neighbors * size_4 + + centers * size_3 + + negative_cell_shifts_x * size_2 + + negative_cell_shifts_y * size_1 + + 
negative_cell_shifts_z + ) + + unique_id_argsort = unique_id.argsort() + unique_id_inverse_argsort = unique_id_inverse.argsort() + + corresponding_edges = torch.empty_like(centers) + corresponding_edges[unique_id_argsort] = unique_id_inverse_argsort + + return corresponding_edges.to(array.dtype) + + +def edge_array_to_nef( + edge_array: torch.Tensor, + nef_indices: torch.Tensor, + mask: Optional[torch.Tensor] = None, + fill_value: float = 0.0, +) -> torch.Tensor: + """Converts an edge array to a NEF array. + + :param edge_array: A tensor where the first dimension is the index of + the edge, i.e. with shape (n_edges, ...). + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param mask: An optional boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. If provided, + the output NEF array will have the values in the positions + where the mask is False set to ``fill_value``. + :param fill_value: The value to use to fill the positions in the + NEF array where the mask is False. Only used if ``mask`` is + provided. + + :return: A tensor with the same information as ``edge_array``, + but in NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + If ``mask`` is provided, the values in the positions where + the mask is False are set to ``fill_value``. + """ + if mask is None: + return edge_array[nef_indices] + else: + return torch.where( + mask.reshape(mask.shape + (1,) * (len(edge_array.shape) - 1)), + edge_array[nef_indices], + fill_value, + ) + + +def nef_array_to_edges( + nef_array: torch.Tensor, centers: torch.Tensor, nef_to_edges_neighbor: torch.Tensor +) -> torch.Tensor: + """Converts a NEF array to an edge array. + + :param nef_array: A tensor where the first two dimensions are the + indices of the NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + :param centers: The indices of the center nodes for each edge. + :param nef_to_edges_neighbor: The indices of the edges for each + neighbor in the NEF layout, as returned by the ``get_nef_indices`` function. + + :return: A tensor with the same information as ``nef_array``, + but in edge layout, i.e. with shape (n_edges, ...). + """ + return nef_array[centers, nef_to_edges_neighbor] + + +def compute_reversed_neighbor_list( + nef_indices: torch.Tensor, + corresponding_edges: torch.Tensor, + nef_mask: torch.Tensor, +) -> torch.Tensor: + """ + Creates a reversed neighborlist, where for each + center atom `i` and its neighbor `j` in the original + neighborlist, the position of atom `i` in the list + of neighbors of atom `j` is returned. + + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param corresponding_edges: The indices of the corresponding edges, + as returned by the ``get_corresponding_edges`` function. + :param nef_mask: A boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. + :return: A tensor of the same shape as ``nef_indices``, + where each entry contains the position of the center + atom in the neighborlist of the corresponding neighbor atom. 
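A round trip through the conversion helpers above, on a toy neighbor list with numbers chosen by hand (the import path assumes the `pet_10` module layout introduced by this patch):

```python
import torch
from metatrain.pet_10.modules.nef import (
    edge_array_to_nef,
    get_nef_indices,
    nef_array_to_edges,
)

# 3 nodes; node 0 has two edges, node 1 has one, node 2 has none.
centers = torch.tensor([0, 1, 0])
edge_values = torch.tensor([[10.0], [20.0], [30.0]])

nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices(
    centers, n_nodes=3, n_edges_per_node=2
)
nef = edge_array_to_nef(edge_values, nef_indices, nef_mask, fill_value=0.0)
# NEF layout has shape (n_nodes, n_edges_per_node, ...); padded slots are zero.
assert nef.shape == (3, 2, 1)

# Round trip back to the (n_edges, ...) layout.
recovered = nef_array_to_edges(nef, centers, nef_to_edges_neighbor)
assert torch.equal(recovered, edge_values)
```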
+ """ + num_atoms, max_num_neighbors = nef_indices.shape + + flat_edge_indices = nef_indices.reshape(-1) + flat_positions = torch.arange(max_num_neighbors, device=nef_indices.device).repeat( + num_atoms + ) + flat_mask = nef_mask.reshape(-1) + + if flat_edge_indices.numel() == 0: + max_edge_index = 0 + else: + max_edge_index = int(flat_edge_indices.max().item()) + 1 + size: List[int] = [max_edge_index] + + edge_index_to_position = torch.full( + size, + 0, + dtype=torch.long, + device=nef_indices.device, + ) + edge_index_to_position[flat_edge_indices[flat_mask]] = flat_positions[flat_mask] + + reverse_edge_idx = corresponding_edges[nef_indices] + reversed_neighbor_list = edge_index_to_position[reverse_edge_idx] + reversed_neighbor_list = reversed_neighbor_list.masked_fill(~nef_mask, 0) + + return reversed_neighbor_list diff --git a/src/metatrain/pet_10/modules/structures.py b/src/metatrain/pet_10/modules/structures.py new file mode 100644 index 0000000000..f7e67f9eaf --- /dev/null +++ b/src/metatrain/pet_10/modules/structures.py @@ -0,0 +1,377 @@ +from typing import List, Optional, Tuple + +import torch +from metatensor.torch import Labels +from metatomic.torch import NeighborListOptions, System + +from .adaptive_cutoff import get_adaptive_cutoffs +from .nef import ( + compute_reversed_neighbor_list, + edge_array_to_nef, + get_corresponding_edges, + get_nef_indices, +) +from .utilities import cutoff_func_bump, cutoff_func_cosine + + +def concatenate_structures( + systems: List[System], + neighbor_list_options: NeighborListOptions, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, +]: + """ + Concatenate a list of systems into a single batch. + + :param systems: List of systems to concatenate. + :param neighbor_list_options: Options for the neighbor list. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the concatenated positions, centers, neighbors, + species, cells, cell shifts, system indices, and sample labels. 
+ """ + + positions: List[torch.Tensor] = [] + centers: List[torch.Tensor] = [] + neighbors: List[torch.Tensor] = [] + species: List[torch.Tensor] = [] + cell_shifts: List[torch.Tensor] = [] + cells: List[torch.Tensor] = [] + system_indices: List[torch.Tensor] = [] + atom_indices: List[torch.Tensor] = [] + node_counter = 0 + + for i, system in enumerate(systems): + assert len(system.known_neighbor_lists()) >= 1, "no neighbor list found" + neighbor_list = system.get_neighbor_list(neighbor_list_options) + nl_values = neighbor_list.samples.values + + centers_values = nl_values[:, 0] + neighbors_values = nl_values[:, 1] + cell_shifts_values = nl_values[:, 2:] + + if selected_atoms is not None: + system_selected_atoms = selected_atoms.values[:, 1][ + selected_atoms.values[:, 0] == i + ] + unique_centers = torch.unique(centers_values) + system_selected_atoms = torch.unique( + torch.cat([system_selected_atoms, unique_centers]) + ) + # calculate the mapping from the ghost atoms to the real atoms + if torch.numel(unique_centers) == 0: + max_center_index = -1 + else: + max_center_index = int(unique_centers.max()) + ghost_to_real_index = torch.full( + [ + max_center_index + 1, + ], + -1, + device=centers_values.device, + dtype=centers_values.dtype, + ) + for j, unique_center_index in enumerate(unique_centers): + ghost_to_real_index[unique_center_index] = j + + centers_values = ghost_to_real_index[centers_values] + neighbors_values = ghost_to_real_index[neighbors_values] + else: + system_selected_atoms = torch.arange( + len(system), device=system.positions.device + ) + + positions.append(system.positions[system_selected_atoms]) + species.append(system.types[system_selected_atoms]) + + centers.append(centers_values + node_counter) + neighbors.append(neighbors_values + node_counter) + cell_shifts.append(cell_shifts_values) + + cells.append(system.cell) + + node_counter += len(system_selected_atoms) + system_indices.append( + torch.full((len(system_selected_atoms),), i, device=system.positions.device) + ) + atom_indices.append( + torch.arange(len(system_selected_atoms), device=system.positions.device) + ) + + positions = torch.cat(positions) + centers = torch.cat(centers) + neighbors = torch.cat(neighbors) + species = torch.cat(species) + cells = torch.stack(cells) + cell_shifts = torch.cat(cell_shifts) + system_indices = torch.cat(system_indices) + atom_indices = torch.cat(atom_indices) + + sample_values = torch.stack( + [system_indices, atom_indices], + dim=1, + ) + sample_labels = Labels( + names=["system", "atom"], + values=sample_values, + assume_unique=True, + ) + + return ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) + + +def systems_to_batch( + systems: List[System], + options: NeighborListOptions, + all_species_list: List[int], + species_to_species_index: torch.Tensor, + cutoff_function: str, + cutoff_width: float, + num_neighbors_adaptive: Optional[float] = None, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, + torch.Tensor, + torch.Tensor, +]: + """ + Converts a list of systems to a batch required for the PET model. + + :param systems: List of systems to convert to a batch. + :param options: Options for the neighbor list. + :param all_species_list: List of all atomic species in the dataset. + :param species_to_species_index: Mapping from atomic species to species indices. 
+ :param cutoff_function: Type of the smoothing function at the cutoff. + :param cutoff_width: Width of the cutoff function for a cutoff mask. + :param num_neighbors_adaptive: Optional maximum number of neighbors per atom. + If provided, the adaptive cutoff scheme will be used for each atom to + approximately select this number of neighbors. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the batch tensors. + The batch consists of the following tensors: + - `element_indices_nodes`: The atomic species of the central atoms + - `element_indices_neighbors`: The atomic species of the neighboring atoms + - `edge_vectors`: The cartesian edge vectors between the central atoms and their + neighbors + - `edge_distances`: The distances between the central atoms and their neighbors + - `padding_mask`: A padding mask indicating which neighbors are real, and which + are padded + - `reverse_neighbor_index`: The reversed neighbor list for each central atom + - `cutoff_factors`: The cutoff function values for each edge + - `system_indices`: The system index for each atom in the batch + - `sample_labels`: Labels indicating the system and atom indices for each atom + + """ + ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) = concatenate_structures(systems, options, selected_atoms) + + # somehow the backward of this operation is very slow at evaluation, + # where there is only one cell, therefore we simplify the calculation + # for that case + if len(cells) == 1: + cell_contributions = cell_shifts.to(cells.dtype) @ cells[0] + else: + cell_contributions = torch.einsum( + "ab, abc -> ac", + cell_shifts.to(cells.dtype), + cells[system_indices[centers]], + ) + edge_vectors = positions[neighbors] - positions[centers] + cell_contributions + edge_distances = torch.norm(edge_vectors, dim=-1) + 1e-15 + + if selected_atoms is not None: + if torch.numel(centers) == 0: + num_nodes = 0 + else: + num_nodes = int(centers.max()) + 1 + else: + num_nodes = len(positions) + + if num_neighbors_adaptive is not None: + with torch.profiler.record_function("PET::get_adaptive_cutoffs"): + # Adaptive cutoff scheme to approximately select `num_neighbors_adaptive` + # neighbors for each atom + atomic_cutoffs = get_adaptive_cutoffs( + centers, + edge_distances, + num_neighbors_adaptive, + num_nodes, + options.cutoff, + cutoff_width=cutoff_width, + ) + # Symmetrize the cutoffs between pairs of atoms (PET needs this symmetry + # due to its corresponding edge indexing ij -> ji) + pair_cutoffs = (atomic_cutoffs[centers] + atomic_cutoffs[neighbors]) / 2.0 + with torch.profiler.record_function("PET::adaptive_cutoff_masking"): + # Apply cutoff mask + cutoff_mask = edge_distances <= pair_cutoffs + + pair_cutoffs = pair_cutoffs[cutoff_mask] + centers = centers[cutoff_mask] + neighbors = neighbors[cutoff_mask] + edge_vectors = edge_vectors[cutoff_mask] + cell_shifts = cell_shifts[cutoff_mask] + edge_distances = edge_distances[cutoff_mask] + else: + pair_cutoffs = options.cutoff * torch.ones( + len(centers), device=positions.device, dtype=positions.dtype + ) + + num_neighbors = torch.bincount(centers, minlength=num_nodes) + max_edges_per_node = int(torch.max(num_neighbors)) + + # uncomment these to print out stats on the adaptive cutoff behavior + # print("adaptive_cutoffs", *pair_cutoffs.tolist()) + # print("num_neighbors", *num_neighbors.tolist()) + + if cutoff_function.lower() == "bump": + # use bump switching 
function for adaptive cutoff
+        cutoff_factors = cutoff_func_bump(edge_distances, pair_cutoffs, cutoff_width)
+    elif cutoff_function.lower() == "cosine":
+        # backward-compatible cosine switching for fixed cutoff
+        cutoff_factors = cutoff_func_cosine(edge_distances, pair_cutoffs, cutoff_width)
+    else:
+        raise ValueError(
+            f"Unknown cutoff function type: {cutoff_function}. "
+            f"Supported types are 'Cosine' and 'Bump'."
+        )
+
+    # Convert to NEF (Node-Edge-Feature) format:
+    nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices(
+        centers, num_nodes, max_edges_per_node
+    )
+
+    # Element indices
+    element_indices_nodes = species_to_species_index[species]
+    element_indices_neighbors = element_indices_nodes[neighbors]
+
+    # Send everything to NEF:
+    edge_vectors = edge_array_to_nef(edge_vectors, nef_indices)
+    edge_distances = torch.sqrt(torch.sum(edge_vectors**2, dim=2) + 1e-15)
+    element_indices_neighbors = edge_array_to_nef(
+        element_indices_neighbors, nef_indices
+    )
+    cutoff_factors = edge_array_to_nef(cutoff_factors, nef_indices, nef_mask, 0.0)
+
+    corresponding_edges = get_corresponding_edges(
+        torch.concatenate(
+            [centers.unsqueeze(-1), neighbors.unsqueeze(-1), cell_shifts],
+            dim=-1,
+        )
+    )
+
+    # These are the two arrays we need for message passing with edge reversals,
+    # if indexing happens in a two-dimensional way:
+    # edges_ji = edges_ij[reversed_neighbor_list, neighbors_index]
+    reversed_neighbor_list = compute_reversed_neighbor_list(
+        nef_indices, corresponding_edges, nef_mask
+    )
+    neighbors_index = edge_array_to_nef(neighbors, nef_indices).to(torch.int64)
+
+    # Here, we compute the array that allows indexing into a flattened
+    # version of the edge array (where the first two dimensions are merged):
+    reverse_neighbor_index = (
+        neighbors_index * neighbors_index.shape[1] + reversed_neighbor_list
+    )
+    # At this point, we have `reverse_neighbor_index[~nef_mask] = 0`, which
+    # creates many identical indices and slows down the backward pass enormously.
+    # (See https://github.com/pytorch/pytorch/issues/41162)
+    # We therefore replace the padded indices with a sequence of unique indices.
+    reverse_neighbor_index[~nef_mask] = torch.arange(
+        int(torch.sum(~nef_mask)), device=reverse_neighbor_index.device
+    )
+
+    return (
+        element_indices_nodes,
+        element_indices_neighbors,
+        edge_vectors,
+        edge_distances,
+        nef_mask,
+        reverse_neighbor_index,
+        cutoff_factors,
+        system_indices,
+        sample_labels,
+        centers,
+        nef_to_edges_neighbor,
+    )
+
+
+def get_pair_sample_labels(
+    systems: List[System],
+    sample_labels: Labels,
+    nl_options: NeighborListOptions,
+    device: torch.device,
+) -> Labels:
+    """
+    Builds the pair sample labels for the input ``systems``, based on the pre-computed
+    neighbor list. These are 'off-site', i.e. not including self-interactions.
+
+    :param systems: List of systems to build the pair sample labels for.
+    :param sample_labels: The sample labels for per-atom quantities.
+    :param nl_options: The neighbor list options to use for building the offsite labels.
+    :param device: The device to put the labels on.
+    :return: The pair sample labels for the off-site pairs.
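+
+    Example (an illustrative sketch; ``systems``, ``sample_labels`` and
+    ``nl_options`` are assumed to come from an earlier call to
+    ``concatenate_structures`` / ``systems_to_batch``)::
+
+        pair_labels = get_pair_sample_labels(systems, sample_labels, nl_options, device)
+        # pair_labels.names == ["system", "first_atom", "second_atom",
+        #                       "cell_shift_a", "cell_shift_b", "cell_shift_c"]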
+ """ + sample_names = [ + "system", + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ] + + pair_sample_values = [] + for system_idx, system in enumerate(systems): + neighbor_list = system.get_neighbor_list(nl_options) + nl_values = neighbor_list.samples.values + + pair_sample_values.append( + torch.hstack( + [ + torch.full( + (nl_values.shape[0], 1), + system_idx, + dtype=torch.int32, + device=device, + ), + nl_values, + ], + ) + ) + pair_sample_values = torch.vstack(pair_sample_values) + pair_sample_labels = Labels(sample_names, pair_sample_values).to(device=device) + + return pair_sample_labels diff --git a/src/metatrain/pet_10/modules/transformer.py b/src/metatrain/pet_10/modules/transformer.py new file mode 100644 index 0000000000..3cc75596dc --- /dev/null +++ b/src/metatrain/pet_10/modules/transformer.py @@ -0,0 +1,555 @@ +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from .utilities import DummyModule + + +AVAILABLE_NORMALIZATIONS = ["LayerNorm", "RMSNorm"] +AVAILABLE_TRANSFORMER_TYPES = ["PostLN", "PreLN"] +AVAILABLE_ACTIVATIONS = ["SiLU", "SwiGLU"] + + +class FeedForward(nn.Module): + def __init__(self, d_model: int, dim_feedforward: int, activation: str) -> None: + super().__init__() + + # Check if activation is "swiglu" string + if activation.lower() == "swiglu": + # SwiGLU mode: single projection produces both "value" and "gate" + self.w_in = nn.Linear(d_model, 2 * dim_feedforward, bias=False) + self.w_out = nn.Linear(dim_feedforward, d_model, bias=False) + self.activation = torch.nn.Identity() + self.is_swiglu = True + else: + # Standard mode: regular activation function + self.w_in = nn.Linear(d_model, dim_feedforward, bias=False) + self.w_out = nn.Linear(dim_feedforward, d_model, bias=False) + self.activation = getattr(F, activation.lower()) + self.is_swiglu = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_swiglu: + # SwiGLU activation: split into value and gate + v, g = self.w_in(x).chunk(2, dim=-1) + x = v * torch.sigmoid(g) + x = self.w_out(x) + else: + # Standard activation + x = self.w_in(x) + x = self.activation(x) + x = self.w_out(x) + return x + + +class AttentionBlock(nn.Module): + """ + Multi-head attention block. + + :param total_dim: The total dimension of the input and output tensors. + :param num_heads: The number of attention heads. + :param temperature: An additional scaling factor for attention scores. + This is combined with the standard scaling by the square root of + the head dimension. + :param epsilon: A small value to avoid division by zero. + """ + + def __init__( + self, + total_dim: int, + num_heads: int, + temperature: float, + epsilon: float = 1e-15, + ) -> None: + super(AttentionBlock, self).__init__() + + self.input_linear = nn.Linear(total_dim, 3 * total_dim, bias=False) + self.output_linear = nn.Linear(total_dim, total_dim, bias=False) + + self.num_heads = num_heads + self.epsilon = epsilon + self.temperature = temperature + if total_dim % num_heads != 0: + raise ValueError("total dimension is not divisible by the number of heads") + self.head_dim = total_dim // num_heads + + def forward( + self, x: torch.Tensor, cutoff_factors: torch.Tensor, use_manual_attention: bool + ) -> torch.Tensor: + """ + Forward pass for the attention block. 
+ + :param x: The input tensor, of shape (batch_size, seq_length, total_dim) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: The output tensor, of shape (batch_size, seq_length, total_dim) + """ + initial_shape = x.shape + x = self.input_linear(x) + x = x.reshape( + initial_shape[0], initial_shape[1], 3, self.num_heads, self.head_dim + ) + x = x.permute(2, 0, 3, 1, 4) + + queries, keys, values = x[0], x[1], x[2] + attn_weights = torch.clamp(cutoff_factors[:, None, :, :], self.epsilon) + attn_weights = torch.log(attn_weights) + if use_manual_attention: + x = manual_attention(queries, keys, values, attn_weights, self.temperature) + else: + x = torch.nn.functional.scaled_dot_product_attention( + queries, + keys, + values, + attn_mask=attn_weights, + scale=1.0 / (self.head_dim**0.5 * self.temperature), + ) + x = x.transpose(1, 2).reshape(initial_shape) + x = self.output_linear(x) + return x + + +class TransformerLayer(torch.nn.Module): + """ + Single layer of a Transformer. + + :param d_model: The dimension of the model. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param temperature: An additional scaling factor for attention scores. 
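+
+    Example (an illustrative sketch; the hyperparameter values are arbitrary
+    placeholders and the input tensors are assumed to exist)::
+
+        layer = TransformerLayer(
+            d_model=128, n_heads=4, dim_node_features=128, dim_feedforward=512,
+            norm="LayerNorm", activation="SiLU", transformer_type="PostLN",
+        )
+        node_out, edge_out = layer(
+            node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention=True
+        )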
+ """ + + def __init__( + self, + d_model: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + temperature: float = 1.0, + ) -> None: + super(TransformerLayer, self).__init__() + self.attention = AttentionBlock(d_model, n_heads, temperature) + self.transformer_type = transformer_type + self.d_model = d_model + norm_class = getattr(nn, norm) + if norm == "LayerNorm": + self.norm_attention = norm_class(d_model, bias=False) + self.norm_mlp = norm_class(d_model, bias=False) + else: + self.norm_attention = norm_class(d_model) + self.norm_mlp = norm_class(d_model) + self.mlp = FeedForward(d_model, dim_feedforward, activation) + self.expanded_node_features = False + if dim_node_features != d_model: + self.expanded_node_features = True + self.center_contraction = nn.Linear(dim_node_features, d_model, bias=False) + self.center_expansion = nn.Linear(d_model, dim_node_features, bias=False) + if norm == "LayerNorm": + self.norm_center_features = norm_class(dim_node_features, bias=False) + else: + self.norm_center_features = norm_class(dim_node_features) + self.center_mlp = FeedForward( + dim_node_features, 2 * dim_node_features, activation + ) + else: + self.center_contraction = torch.nn.Identity() + self.center_expansion = torch.nn.Identity() + self.norm_center_features = torch.nn.Identity() + self.center_mlp = torch.nn.Identity() + + def _forward_pre_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + new_tokens = self.attention( + self.norm_attention(tokens), cutoff_factors, use_manual_attention + ) + output_node_embeddings, output_edge_embeddings = torch.split( + new_tokens, [1, new_tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + + output_edge_embeddings = edge_embeddings + output_edge_embeddings + output_edge_embeddings = output_edge_embeddings + self.mlp( + self.norm_mlp(output_edge_embeddings) + ) + + return output_node_embeddings, output_edge_embeddings + + def _forward_post_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + tokens = self.norm_attention( + tokens + self.attention(tokens, cutoff_factors, use_manual_attention) + ) + tokens = self.norm_mlp(tokens + self.mlp(tokens)) + output_node_embeddings, output_edge_embeddings = torch.split( + tokens, [1, tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + 
) + return output_node_embeddings, output_edge_embeddings + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for a single Transformer layer. + + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + if self.transformer_type == "PostLN": + node_embeddings, edge_embeddings = self._forward_post_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + if self.transformer_type == "PreLN": + node_embeddings, edge_embeddings = self._forward_pre_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class Transformer(torch.nn.Module): + """ + Transformer implementation. + + :param d_model: The dimension of the model. + :param num_layers: The number of transformer layers. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param attention_temperature: The temperature scaling factor for attention + scores. This is combined with the standard scaling by the square root of + the head dimension. + """ + + def __init__( + self, + d_model: int, + num_layers: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + attention_temperature: float = 1.0, + ) -> None: + super(Transformer, self).__init__() + if norm not in AVAILABLE_NORMALIZATIONS: + raise ValueError( + f"Unknown normalization flag: {norm}. " + f"Please choose from: {AVAILABLE_NORMALIZATIONS}" + ) + + if transformer_type not in AVAILABLE_TRANSFORMER_TYPES: + raise ValueError( + f"Unknown transformer flag: {transformer_type}. " + f"Please choose from: {AVAILABLE_TRANSFORMER_TYPES}" + ) + self.transformer_type = transformer_type + + if activation not in AVAILABLE_ACTIVATIONS: + raise ValueError( + f"Unknown activation flag: {activation}. 
" + f"Please choose from: {AVAILABLE_ACTIVATIONS}" + ) + + self.layers = nn.ModuleList( + [ + TransformerLayer( + d_model=d_model, + n_heads=n_heads, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + temperature=attention_temperature, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the Transformer. + + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + for layer in self.layers: + node_embeddings, edge_embeddings = layer( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class CartesianTransformer(torch.nn.Module): + """ + Cartesian Transformer implementation for handling 3D coordinates. + + :param cutoff: The cutoff distance for neighbor interactions. + :param cutoff_width: The width of the cutoff function. + :param d_model: The dimension of the model. + :param n_head: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param n_layers: The number of transformer layers. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param attention_temperature: The temperature scaling factor for attention scores. + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param n_atomic_species: The number of atomic species. + :param is_first: Whether this is the first transformer in the model. 
+ """ + + def __init__( + self, + cutoff: float, + cutoff_width: float, + d_model: int, + n_head: int, + dim_node_features: int, + dim_feedforward: int, + n_layers: int, + norm: str, + activation: str, + attention_temperature: float, + transformer_type: str, + n_atomic_species: int, + is_first: bool, + ) -> None: + super(CartesianTransformer, self).__init__() + self.is_first = is_first + self.cutoff = cutoff + self.cutoff_width = cutoff_width + self.trans = Transformer( + d_model=d_model, + num_layers=n_layers, + n_heads=n_head, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + attention_temperature=attention_temperature, + ) + + self.edge_embedder = nn.Linear(4, d_model, bias=False) + + if not is_first: + n_merge = 3 + else: + n_merge = 2 + + self.compress = nn.Sequential( + nn.Linear(n_merge * d_model, d_model, bias=False), + torch.nn.SiLU(), + nn.Linear(d_model, d_model, bias=False), + ) + + self.neighbor_embedder = DummyModule() # for torchscript + if not is_first: + self.neighbor_embedder = nn.Embedding(n_atomic_species, d_model) + + def forward( + self, + input_node_embeddings: torch.Tensor, + input_messages: torch.Tensor, + element_indices_neighbors: torch.Tensor, + edge_vectors: torch.Tensor, + padding_mask: torch.Tensor, + edge_distances: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the CartesianTransformer. + + :param input_node_embeddings: The input node embeddings, of shape + (n_nodes, d_model) + :param input_messages: The input messages to the transformer, of shape + (n_nodes, max_num_neighbors, d_model) + :param element_indices_neighbors: The atomic species of the neighboring atoms, + of shape (n_nodes, max_num_neighbors) + :param edge_vectors: The cartesian edge vectors between the central atoms and + their neighbors, of shape (n_nodes, max_num_neighbors, 3) + :param padding_mask: A padding mask indicating which neighbors are real, and + which are padded, of shape (n_nodes, max_num_neighbors) + :param edge_distances: The distances between the central atoms and their + neighbors, of shape (n_nodes, max_num_neighbors) + :param cutoff_factors: The cutoff factors for the edges, of shape + (n_nodes, max_num_neighbors) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). 
+ :return: A tuple containing: + - The output node embeddings, of shape (n_nodes, d_model) + - The output edge embeddings, of shape (n_nodes, max_num_neighbors, d_model) + """ + node_embeddings = input_node_embeddings + edge_embeddings = [edge_vectors, edge_distances[:, :, None]] + edge_embeddings = torch.cat(edge_embeddings, dim=2) + edge_embeddings = self.edge_embedder(edge_embeddings) + + if not self.is_first: + neighbor_elements_embeddings = self.neighbor_embedder( + element_indices_neighbors + ) + edge_tokens = torch.cat( + [edge_embeddings, neighbor_elements_embeddings, input_messages], dim=2 + ) + else: + neighbor_elements_embeddings = torch.empty( + 0, device=edge_vectors.device, dtype=edge_vectors.dtype + ) # for torch script + edge_tokens = torch.cat([edge_embeddings, input_messages], dim=2) + + edge_tokens = self.compress(edge_tokens) + # tokens = torch.cat([node_elements_embedding[:, None, :], tokens], dim=1) + + padding_mask_with_central_token = torch.ones( + padding_mask.shape[0], dtype=torch.bool, device=padding_mask.device + ) + total_padding_mask = torch.cat( + [padding_mask_with_central_token[:, None], padding_mask], dim=1 + ) + + cutoff_subfactors = torch.ones( + padding_mask.shape[0], + dtype=cutoff_factors.dtype, + device=padding_mask.device, + ) + cutoff_factors = torch.cat([cutoff_subfactors[:, None], cutoff_factors], dim=1) + cutoff_factors[~total_padding_mask] = 0.0 + + cutoff_factors = cutoff_factors[:, None, :] + cutoff_factors = cutoff_factors.repeat(1, cutoff_factors.shape[2], 1) + + initial_num_tokens = edge_vectors.shape[1] + max_num_tokens = input_messages.shape[1] + + output_node_embeddings, output_edge_embeddings = self.trans( + node_embeddings[:, None, :], + edge_tokens[:, :max_num_tokens, :], + cutoff_factors=cutoff_factors[ + :, : (max_num_tokens + 1), : (max_num_tokens + 1) + ], + use_manual_attention=use_manual_attention, + ) + if max_num_tokens < initial_num_tokens: + padding = torch.zeros( + output_edge_embeddings.shape[0], + initial_num_tokens - max_num_tokens, + output_edge_embeddings.shape[2], + device=output_edge_embeddings.device, + ) + output_edge_embeddings = torch.cat([output_edge_embeddings, padding], dim=1) + output_node_embeddings = output_node_embeddings.squeeze(1) + return output_node_embeddings, output_edge_embeddings + + +def manual_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: torch.Tensor, + temperature: float, +) -> torch.Tensor: + """ + Implements the attention operation manually, using basic PyTorch operations. + We need it because the built-in PyTorch attention does not support double backward, + which is needed when training with conservative forces. + + :param q: The queries + :param k: The keys + :param v: The values + :param attn_mask: The attention mask + :param temperature: An additional scaling factor for attention scores. 
+ :return: The result of the attention operation + """ + attention_weights = ( + torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5 * temperature) + ) + attn_mask + attention_weights = attention_weights.softmax(dim=-1) + attention_output = torch.matmul(attention_weights, v) + return attention_output diff --git a/src/metatrain/pet_10/modules/utilities.py b/src/metatrain/pet_10/modules/utilities.py new file mode 100644 index 0000000000..5795c5aed8 --- /dev/null +++ b/src/metatrain/pet_10/modules/utilities.py @@ -0,0 +1,63 @@ +import torch + + +def cutoff_func_bump( + values: torch.Tensor, cutoff: torch.Tensor, width: float, eps: float = 1e-6 +) -> torch.Tensor: + """ + Bump cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :param eps: Avoid computing at values too close to the edges. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + f[mask_active] = 0.5 * ( + 1 + torch.tanh(1 / torch.tan(torch.pi * scaled_values[mask_active])) + ) + f[mask_smaller] = 1.0 + + return f + + +def cutoff_func_cosine( + values: torch.Tensor, cutoff: torch.Tensor, width: float +) -> torch.Tensor: + """ + Cosine cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + + f[mask_active] = 0.5 + 0.5 * torch.cos(torch.pi * scaled_values[mask_active]) + f[mask_smaller] = 1.0 + return f + + +class DummyModule(torch.nn.Module): + """Dummy torch module to make torchscript happy. 
+ This model should never be run""" + + def __init__(self) -> None: + super(DummyModule, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise RuntimeError("This model should never be run") diff --git a/src/metatrain/pet_10/trainer.py b/src/metatrain/pet_10/trainer.py new file mode 100644 index 0000000000..49e82212bc --- /dev/null +++ b/src/metatrain/pet_10/trainer.py @@ -0,0 +1,639 @@ +import copy +import logging +import math +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Union, cast + +import torch +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader, DistributedSampler + +from metatrain.utils.abc import ModelInterface, TrainerInterface +from metatrain.utils.additive import get_remove_additive_transform +from metatrain.utils.augmentation import RotationalAugmenter +from metatrain.utils.data import ( + CollateFn, + CombinedDataLoader, + Dataset, + get_num_workers, + unpack_batch, + validate_num_workers, +) +from metatrain.utils.distributed.batch_utils import should_skip_batch +from metatrain.utils.distributed.distributed_data_parallel import ( + DistributedDataParallel, +) +from metatrain.utils.distributed.slurm import DistributedEnvironment +from metatrain.utils.evaluate_model import evaluate_model +from metatrain.utils.io import check_file_extension +from metatrain.utils.logging import ROOT_LOGGER, MetricLogger +from metatrain.utils.loss import LossAggregator, LossSpecification +from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists_transform, +) +from metatrain.utils.per_atom import average_by_num_atoms +from metatrain.utils.scaler import get_remove_scale_transform +from metatrain.utils.transfer import batch_to + +from . import checkpoints +from .documentation import TrainerHypers +from .model import PET +from .modules.finetuning import apply_finetuning_strategy + + +def get_scheduler( + optimizer: torch.optim.Optimizer, + train_hypers: TrainerHypers, + steps_per_epoch: int, +) -> LambdaLR: + """ + Get a CosineAnnealing learning-rate scheduler with warmup + + :param optimizer: The optimizer for which to create the scheduler. + :param train_hypers: The training hyperparameters. + :param steps_per_epoch: The number of steps per epoch. + :return: The learning rate scheduler. 
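+
+    The returned scheduler multiplies the base learning rate by
+    ``step / warmup_steps`` during warmup and by
+    ``0.5 * (1 + cos(pi * progress))`` afterwards, where
+    ``warmup_steps = int(warmup_fraction * num_epochs * steps_per_epoch)`` and
+    ``progress`` grows linearly from 0 to 1 over the remaining steps.
+
+    A minimal usage sketch (``optimizer``, ``train_hypers`` and
+    ``train_dataloader`` are assumed to exist)::
+
+        scheduler = get_scheduler(optimizer, train_hypers, len(train_dataloader))
+        # after every optimizer.step() in the training loop:
+        scheduler.step()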
+ """ + total_steps = train_hypers["num_epochs"] * steps_per_epoch + warmup_steps = int(train_hypers["warmup_fraction"] * total_steps) + min_lr_ratio = 0.0 # hardcoded for now, could be made configurable in the future + + def lr_lambda(current_step: int) -> float: + if current_step < warmup_steps: + # Linear warmup + return float(current_step) / float(max(1, warmup_steps)) + else: + # Cosine decay + progress = (current_step - warmup_steps) / float( + max(1, total_steps - warmup_steps) + ) + cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) + return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay + + scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) + return scheduler + + +class Trainer(TrainerInterface[TrainerHypers]): + __checkpoint_version__ = 12 + + def __init__(self, hypers: TrainerHypers) -> None: + super().__init__(hypers) + + self.optimizer_state_dict: Optional[Dict[str, Any]] = None + self.scheduler_state_dict: Optional[Dict[str, Any]] = None + self.epoch: Optional[int] = None + self.best_epoch: Optional[int] = None + self.best_metric: Optional[float] = None + self.best_model_state_dict: Optional[Dict[str, Any]] = None + self.best_optimizer_state_dict: Optional[Dict[str, Any]] = None + + def train( + self, + model: PET, + dtype: torch.dtype, + devices: List[torch.device], + train_datasets: List[Union[Dataset, torch.utils.data.Subset]], + val_datasets: List[Union[Dataset, torch.utils.data.Subset]], + checkpoint_dir: str, + ) -> None: + assert dtype in PET.__supported_dtypes__ + + is_distributed = self.hypers["distributed"] + is_finetune = self.hypers["finetune"]["read_from"] is not None + + if is_distributed: + if len(devices) > 1: + raise ValueError( + "Requested distributed training with the `multi-gpu` device. " + " If you want to run distributed training with PET, please " + "set `device` to cuda." 
+                )
+            # the calculation of the device number works both when GPUs on different
+            # processes are not visible to each other and when they are
+            distr_env = DistributedEnvironment(self.hypers["distributed_port"])
+            device_number = distr_env.local_rank % torch.cuda.device_count()
+            device = torch.device("cuda", device_number)
+            torch.distributed.init_process_group(backend="nccl", device_id=device)
+            world_size = torch.distributed.get_world_size()
+            rank = torch.distributed.get_rank()
+        else:
+            rank = 0
+            device = devices[0]
+            # only one device, as we don't support non-distributed multi-gpu for now
+
+        if is_distributed:
+            logging.info(f"Training on {world_size} devices with dtype {dtype}")
+        else:
+            logging.info(f"Training on device {device} with dtype {dtype}")
+
+        # Apply fine-tuning strategy if provided
+        if is_finetune:
+            assert self.hypers["finetune"]["read_from"] is not None  # for mypy
+            model = apply_finetuning_strategy(model, self.hypers["finetune"])
+            method = self.hypers["finetune"]["method"]
+            num_params = sum(p.numel() for p in model.parameters())
+            num_trainable_params = sum(
+                p.numel() for p in model.parameters() if p.requires_grad
+            )
+
+            logging.info(f"Applied finetuning strategy: {method}")
+            logging.info(
+                f"Number of trainable parameters: {num_trainable_params} "
+                f"[{num_trainable_params / num_params:.2%}]"
+            )
+            inherit_heads = self.hypers["finetune"]["inherit_heads"]
+            if inherit_heads:
+                logging.info(
+                    "Inheriting initial weights for heads and last layers for targets: "
+                    f"from {list(inherit_heads.values())} to "
+                    f"{list(inherit_heads.keys())}"
+                )
+
+        # Move the model to the device and dtype:
+        model.to(device=device, dtype=dtype)
+        # The additive models of PET are always in float64 (to avoid numerical errors in
+        # the composition weights, which can be very large).
+ for additive_model in model.additive_models: + additive_model.to(dtype=torch.float64) + model.scaler.to(dtype=torch.float64) + + logging.info("Calculating composition weights") + model.additive_models[0].train_model( # this is the composition model + train_datasets, + model.additive_models[1:], + self.hypers["batch_size"], + is_distributed, + self.hypers["atomic_baseline"], + ) + + if self.hypers["scale_targets"]: + logging.info("Calculating scaling weights") + model.scaler.train_model( + train_datasets, + model.additive_models, + self.hypers["batch_size"], + is_distributed, + self.hypers["fixed_scaling_weights"], + ) + + logging.info("Setting up data loaders") + + if is_distributed: + train_samplers = [ + DistributedSampler( + train_dataset, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + ) + for train_dataset in train_datasets + ] + val_samplers = [ + DistributedSampler( + val_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + for val_dataset in val_datasets + ] + else: + train_samplers = [None] * len(train_datasets) + val_samplers = [None] * len(val_datasets) + + # Extract additive models and scaler and move them to CPU/float64 so they + # can be used in the collate function + model.additive_models[0].weights_to(device="cpu", dtype=torch.float64) + additive_models = copy.deepcopy( + model.additive_models.to(dtype=torch.float64, device="cpu") + ) + model.additive_models.to(device) + model.additive_models[0].weights_to(device=device, dtype=torch.float64) + model.scaler.scales_to(device="cpu", dtype=torch.float64) + scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu")) + model.scaler.to(device) + model.scaler.scales_to(device=device, dtype=torch.float64) + + # Create collate functions: + dataset_info = model.dataset_info + train_targets = dataset_info.targets + extra_data_info = dataset_info.extra_data + rotational_augmenter = RotationalAugmenter( + target_info_dict=train_targets, extra_data_info_dict=extra_data_info + ) + requested_neighbor_lists = get_requested_neighbor_lists(model) + collate_fn_train = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ + rotational_augmenter.apply_random_augmentations, + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + collate_fn_val = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ # no augmentation for validation + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + + # Create dataloader for the training datasets: + if self.hypers["num_workers"] is None: + num_workers = get_num_workers() + logging.info( + "Number of workers for data-loading not provided and chosen " + f"automatically. Using {num_workers} workers." + ) + else: + num_workers = self.hypers["num_workers"] + validate_num_workers(num_workers) + + train_dataloaders = [] + for train_dataset, train_sampler in zip( + train_datasets, train_samplers, strict=True + ): + if len(train_dataset) < self.hypers["batch_size"]: + raise ValueError( + f"A training dataset has fewer samples " + f"({len(train_dataset)}) than the batch size " + f"({self.hypers['batch_size']}). " + "Please reduce the batch size." 
+ ) + train_dataloaders.append( + DataLoader( + dataset=train_dataset, + batch_size=self.hypers["batch_size"], + sampler=train_sampler, + shuffle=( + # the sampler takes care of this (if present) + train_sampler is None + ), + drop_last=( + # the sampler takes care of this (if present) + train_sampler is None + ), + collate_fn=collate_fn_train, + num_workers=num_workers, + ) + ) + train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True) + + # Create dataloader for the validation datasets: + val_dataloaders = [] + for val_dataset, val_sampler in zip(val_datasets, val_samplers, strict=True): + val_dataloaders.append( + DataLoader( + dataset=val_dataset, + batch_size=self.hypers["batch_size"], + sampler=val_sampler, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_val, + num_workers=num_workers, + ) + ) + val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False) + + if is_distributed: + model = DistributedDataParallel(model, device_ids=[device]) + + outputs_list = [] + for target_name, target_info in train_targets.items(): + outputs_list.append(target_name) + for gradient_name in target_info.gradients: + outputs_list.append(f"{target_name}_{gradient_name}_gradients") + + # Create a loss function: + loss_hypers = cast(Dict[str, LossSpecification], self.hypers["loss"]) # mypy + loss_fn = LossAggregator(targets=train_targets, config=loss_hypers) + logging.info("Using the following loss functions:") + for name, info in loss_fn.metadata.items(): + logging.info(f"{name}:") + main = {k: v for k, v in info.items() if k != "gradients"} + logging.info(main) + if "gradients" not in info or len(info["gradients"]) == 0: + continue + logging.info("With gradients:") + for grad, ginfo in info["gradients"].items(): + logging.info(f"\t{name}::{grad}: {ginfo}") + + if self.hypers["weight_decay"] is not None: + optimizer = torch.optim.AdamW( + model.parameters(), + lr=self.hypers["learning_rate"], + weight_decay=self.hypers["weight_decay"], + ) + else: + optimizer = torch.optim.Adam( + model.parameters(), lr=self.hypers["learning_rate"] + ) + + if self.optimizer_state_dict is not None and not is_finetune: + # try to load the optimizer state dict, but this is only possible + # if there are no new targets in the model (new parameters) + if not (model.module if is_distributed else model).has_new_targets: + optimizer.load_state_dict(self.optimizer_state_dict) + + # Create a learning rate scheduler + lr_scheduler = get_scheduler(optimizer, self.hypers, len(train_dataloader)) + + if self.scheduler_state_dict is not None and not is_finetune: + # same as the optimizer, try to load the scheduler state dict + if not (model.module if is_distributed else model).has_new_targets: + lr_scheduler.load_state_dict(self.scheduler_state_dict) + + per_structure_targets = self.hypers["per_structure_targets"] + + # Log the initial learning rate: + logging.info(f"Base learning rate: {self.hypers['learning_rate']}") + + start_epoch = -1 if self.epoch is None else self.epoch + epoch = start_epoch + + # Save the untrained model checkpoint: + if self.best_metric is None: + self.best_metric = float("inf") + logging.info("Saving untrained model checkpoint") + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = 
lr_scheduler.state_dict() + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_untrained.ckpt", + ) + + # Train the model: + logging.info("Starting training") + start_epoch = start_epoch + 1 + + for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]): + if is_distributed: + for train_sampler in train_samplers: + train_sampler.set_epoch(epoch) + train_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + val_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + if self.hypers["log_mae"]: + train_mae_calculator = MAEAccumulator( + self.hypers["log_separate_blocks"] + ) + val_mae_calculator = MAEAccumulator(self.hypers["log_separate_blocks"]) + + train_loss = 0.0 + for batch in train_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + optimizer.zero_grad() + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=True, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + train_loss_batch = loss_fn(predictions, targets, extra_data) + + if is_distributed: + # make sure all parameters contribute to the gradient calculation + # to make torch DDP happy + for param in model.parameters(): + train_loss_batch += 0.0 * param.sum() + + train_loss_batch.backward() + torch.nn.utils.clip_grad_norm_( + model.parameters(), self.hypers["grad_clip_norm"] + ) + optimizer.step() + lr_scheduler.step() + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(train_loss_batch) + train_loss += train_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + train_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + train_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_train_info = train_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_train_info.update( + train_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + val_loss = 0.0 + for batch in val_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=False, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + val_loss_batch = loss_fn(predictions, targets, extra_data) + + if 
is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(val_loss_batch) + val_loss += val_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + val_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + val_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_val_info = val_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_val_info.update( + val_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + # Now we log the information: + finalized_train_info = { + "loss": train_loss, + **finalized_train_info, + } + finalized_val_info = { + "loss": val_loss, + **finalized_val_info, + } + + if epoch == start_epoch: + metric_logger = MetricLogger( + log_obj=ROOT_LOGGER, + dataset_info=( + model.module if is_distributed else model + ).dataset_info, + initial_metrics=[finalized_train_info, finalized_val_info], + names=["training", "validation"], + ) + if epoch % self.hypers["log_interval"] == 0: + metric_logger.log( + metrics=[finalized_train_info, finalized_val_info], + epoch=epoch, + rank=rank, + learning_rate=optimizer.param_groups[0]["lr"], + ) + + val_metric = get_selected_metric( + finalized_val_info, self.hypers["best_model_metric"] + ) + if val_metric < self.best_metric: + self.best_metric = val_metric + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + if epoch % self.hypers["checkpoint_interval"] == 0: + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + self.epoch = epoch + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_{epoch}.ckpt", + ) + + # prepare for the checkpoint that will be saved outside the function + self.epoch = epoch + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + + if is_distributed: + torch.distributed.destroy_process_group() + + def save_checkpoint(self, model: ModelInterface, path: Union[str, Path]) -> None: + checkpoint = model.get_checkpoint() + if self.best_model_state_dict is not None: + self.best_model_state_dict["finetune_config"] = model.finetune_config + checkpoint.update( + { + "trainer_ckpt_version": self.__checkpoint_version__, + "train_hypers": self.hypers, + "epoch": self.epoch, + "optimizer_state_dict": self.optimizer_state_dict, + "scheduler_state_dict": self.scheduler_state_dict, + "best_epoch": self.best_epoch, + "best_metric": self.best_metric, + "best_model_state_dict": self.best_model_state_dict, + "best_optimizer_state_dict": self.best_optimizer_state_dict, + } + ) + torch.save( + checkpoint, + check_file_extension(path, ".ckpt"), + ) + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + hypers: TrainerHypers, + context: Literal["restart", "finetune"], + ) -> "Trainer": + trainer = cls(hypers) + trainer.optimizer_state_dict = 
checkpoint["optimizer_state_dict"] + trainer.scheduler_state_dict = checkpoint["scheduler_state_dict"] + trainer.epoch = checkpoint["epoch"] + trainer.best_epoch = checkpoint["best_epoch"] + trainer.best_metric = checkpoint["best_metric"] + trainer.best_model_state_dict = checkpoint["best_model_state_dict"] + trainer.best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"] + + return trainer + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["trainer_ckpt_version"] == v: + update = getattr(checkpoints, f"trainer_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["trainer_ckpt_version"] = v + 1 + + if checkpoint["trainer_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using " + f"trainer version {checkpoint['trainer_ckpt_version']}, while the " + f"current trainer version is {cls.__checkpoint_version__}." + ) + return checkpoint diff --git a/src/metatrain/pet_11/__init__.py b/src/metatrain/pet_11/__init__.py new file mode 100644 index 0000000000..01db7d8aa9 --- /dev/null +++ b/src/metatrain/pet_11/__init__.py @@ -0,0 +1,21 @@ +from .model import PET +from .trainer import Trainer + + +__model__ = PET +__trainer__ = Trainer +__capabilities__ = { + "supported_devices": __model__.__supported_devices__, + "supported_dtypes": __model__.__supported_dtypes__, +} + +__authors__ = [ + ("Sergey Pozdnyakov ", "@spozdn"), + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] + +__maintainers__ = [ + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] diff --git a/src/metatrain/pet_11/checkpoints.py b/src/metatrain/pet_11/checkpoints.py new file mode 100644 index 0000000000..020b1e6590 --- /dev/null +++ b/src/metatrain/pet_11/checkpoints.py @@ -0,0 +1,422 @@ +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + + +########################### +# MODEL ################### +########################### + + +def model_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "additive_models.0.model.type_to_index" not in state_dict: + state_dict["additive_models.0.model.type_to_index"] = state_dict.pop( + "additive_models.0.type_to_index" + ) + + +def model_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "train_hypers" in checkpoint: + finetune_config = checkpoint["train_hypers"].get("finetune", {}) + else: + finetune_config = {} + state_dict["finetune_config"] = finetune_config + + +def model_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["epoch"] = checkpoint.get("epoch") + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + if checkpoint["best_model_state_dict"] is not None: + checkpoint["best_model_state_dict"] = checkpoint.get("best_model_state_dict") + + +def model_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 checkpoint to v5. + + :param checkpoint: The checkpoint to update. 
+ """ + checkpoint["model_data"]["dataset_info"]._atomic_types = list( + checkpoint["model_data"]["dataset_info"]._atomic_types + ) + + +def model_update_v5_v6(checkpoint: dict) -> None: + """ + Update a v5 checkpoint to v6. + + :param checkpoint: The checkpoint to update. + """ + if not checkpoint["best_model_state_dict"]: + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + +def model_update_v6_v7(checkpoint: dict) -> None: + """ + Update model checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to be updated. + """ + # this update consists in changes in the scaler + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if ( + "scaler.scales" not in state_dict + and "scaler.dummy_buffer" in state_dict + and "scaler.model.type_to_index" in state_dict + ): + continue # already updated + old_scales_tensor = state_dict.pop("scaler.scales") + old_output_name_to_output_index = {} + for target_index, target_name in enumerate( + checkpoint["model_data"]["dataset_info"].targets.keys() + ): + old_output_name_to_output_index[target_name] = target_index + state_dict["scaler.dummy_buffer"] = torch.tensor( + [0.0], dtype=old_scales_tensor.dtype + ) + state_dict["scaler.model.type_to_index"] = state_dict[ + "additive_models.0.model.type_to_index" + ] + for target_name, target_info in checkpoint["model_data"][ + "dataset_info" + ].targets.items(): + layout = target_info.layout + if layout.sample_names == ["system"]: + samples = Labels(["atomic_type"], torch.tensor([[-1]])) + + elif layout.sample_names == ["system", "atom"]: + samples = Labels( + ["atomic_type"], + torch.arange( + len(checkpoint["model_data"]["dataset_info"].atomic_types) + ).reshape(-1, 1), + ) + else: + raise ValueError( # will never happen + "Unknown sample kind. Please contact the developers." + ) + scales_tensormap = TensorMap( + keys=layout.keys, + blocks=[ + TensorBlock( + values=torch.full( # important when scale_targets=False + (len(samples), len(block.properties)), + old_scales_tensor[ + old_output_name_to_output_index[target_name] + ], + dtype=torch.float64, + ), + samples=samples, + components=[], + properties=block.properties, + ) + for block in layout.blocks() + ], + ) + state_dict[f"scaler.{target_name}_scaler_buffer"] = mts.save_buffer( + mts.make_contiguous(scales_tensormap) + ) + + +def model_update_v7_v8(checkpoint: dict) -> None: + """ + Update a v7 checkpoint to v8. + + :param checkpoint: The checkpoint to update. + """ + if "node_embedding.weight" in checkpoint["model_state_dict"]: + ############################################## + # **Updating the large-scale PET checkpoints** + ############################################## + + checkpoint["model_data"]["model_hypers"]["normalization"] = "RMSNorm" + checkpoint["model_data"]["model_hypers"]["activation"] = "SwiGLU" + checkpoint["model_data"]["model_hypers"]["d_node"] = ( + 4 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PreLN" + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "feedforward" + checkpoint["model_data"]["model_hypers"]["d_feedforward"] = ( + 2 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + if (state_dict := checkpoint.get("model_state_dict")) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + if "embedding." 
in k and "node" not in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if "node_embedding." in k: + k = k.replace("node_embedding.", "node_embedders.0.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + if "combination_rmsnorms" in k: + k = k.replace("combination_rmsnorms", "combination_norms") + new_state_dict[k] = v + checkpoint["model_state_dict"] = new_state_dict + + # for the large-scale checkpoints, the model for evaluation is always + # taken to be the last + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + else: + ############################################### + # **Updating the standard old PET checkpoints** + ############################################### + + # Adding the option for choosing the normalization type + if "normalization" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["normalization"] = "LayerNorm" + # Adding the option for choosing the activation function + if "activation" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["activation"] = "SiLU" + # Setting the node features dimension to be the same as d_pet if not specified + if "d_node" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["d_node"] = checkpoint[ + "model_data" + ]["model_hypers"]["d_pet"] + # Setting the default transformer type to PostLN if not specified + if "transformer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PostLN" + if "featurizer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "residual" + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + # Replacing the nn.Sequential MLP with a custom FeedForward module + if ".mlp.0" in k: + k = k.replace(".mlp.0", ".mlp.w_in") + if ".mlp.3" in k: + k = k.replace(".mlp.3", ".mlp.w_out") + # Moving the node embedder to a top-level PET model attribute + if "embedding." in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".node_embedder." in k: + key_content = k.split(".") + k = ".".join( + ["node_embedders", key_content[1], key_content[-1]] + ) + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + new_state_dict[k] = v + checkpoint[key] = new_state_dict + + +def model_update_v8_v9(checkpoint: dict) -> None: + """ + Update a v8 checkpoint to v9. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "finetune_config" in state_dict: + if "inherit_heads" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["inherit_heads"] = {} + if "method" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["method"] = "full" + + +def model_update_v9_v10(checkpoint: dict) -> None: + """ + Update a v9 checkpoint to v10. + + :param checkpoint: The checkpoint to update. 
+ """ + # Adding the num_neighbors_adaptive hyperparameter if not present + if "num_neighbors_adaptive" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["num_neighbors_adaptive"] = None + if "cutoff_function" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["cutoff_function"] = "Cosine" + + +def model_update_v10_v11(checkpoint: dict) -> None: + """ + Update a v10 checkpoint to v11. + + :param checkpoint: The checkpoint to update. + """ + # Adding the attention_temperature hyperparameter if not present + if "attention_temperature" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["attention_temperature"] = 1.0 + + +########################### +# TRAINER ################# +########################### + + +def trainer_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 Trainer checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["scheduler_factor"] = checkpoint["train_hypers"].get( + "scheduler_factor", 0.5 + ) + + +def trainer_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 Trainer checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + +def trainer_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 Trainer checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + new_loss_hypers[target_name] = { + "type": old_loss_hypers["type"], + "weight": old_loss_hypers["weights"].get(target_name, 1.0), + "reduction": old_loss_hypers["reduction"], + "sliding_factor": old_loss_hypers.get("sliding_factor", None), + } + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 Trainer checkpoint to v5. + + :param checkpoint: The checkpoint to update. + """ + raise ValueError( + "In order to use this checkpoint, you need metatrain 2025.10 or earlier. " + "You can install it with `pip install metatrain==2025.10`." + ) + + +def trainer_update_v5_v6(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 5 to version 6. + + :param checkpoint: The checkpoint to update. + """ + # num_workers=0 means that the main process will do the data loading, which is + # equivalent to not setting it (this was the behavior before v6) + checkpoint["train_hypers"]["num_workers"] = 0 + + +def trainer_update_v6_v7(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["fixed_scaling_weights"] = {} + + +def trainer_update_v7_v8(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. 
+ """ + # remove all entries in the loss `sliding_factor` + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + # retain everything except sliding_factor for each target + new_loss_hypers[target_name] = { + k: v + for k, v in old_loss_hypers[target_name].items() + if k != "sliding_factor" + } + + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v8_v9(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. + """ + # Adding the empty finetune config if not present + if "finetune" not in checkpoint["train_hypers"]: + checkpoint["train_hypers"]["finetune"] = { + "read_from": None, + "method": "full", + "config": {}, + "inherit_heads": {}, + } + + +def trainer_update_v9_v10(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 9 to version 10. + + :param checkpoint: The checkpoint to update. + """ + # Ensuring that the finetune read_from is None if not specified + checkpoint["train_hypers"]["remove_composition_contribution"] = True + + +def trainer_update_v10_v11(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 10 to version 11. + + :param checkpoint: The checkpoint to update. + """ + # - Remove the ``remove_composition_contribution`` hyper. + # - Rename ``fixed_composition_weights`` to ``atomic_baseline``. + # - If ``remove_composition_contribution`` is False, set all atomic baselines + # to 0.0 for all targets. + use_atomic_baseline = checkpoint["train_hypers"].pop( + "remove_composition_contribution" + ) + atomic_baseline = checkpoint["train_hypers"].pop("fixed_composition_weights") + + if not use_atomic_baseline: + # Just set + dataset_info = checkpoint["model_data"]["dataset_info"] + atomic_baseline = {target_name: 0.0 for target_name in dataset_info.targets} + + checkpoint["train_hypers"]["atomic_baseline"] = atomic_baseline + + +def trainer_update_v11_v12(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 11 to version 12. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["batch_atom_bounds"] = [None, None] diff --git a/src/metatrain/pet_11/documentation.py b/src/metatrain/pet_11/documentation.py new file mode 100644 index 0000000000..c0bcc5011a --- /dev/null +++ b/src/metatrain/pet_11/documentation.py @@ -0,0 +1,251 @@ +""" +PET +=== + +PET is a cleaner, more user-friendly reimplementation of the original +PET model :footcite:p:`pozdnyakov_smooth_2023`. It is designed for better +modularity and maintainability, while preseving compatibility with the original +PET implementation in ``metatrain``. It also adds new features like long-range +features, better fine-tuning implementation, a possibility to train on +arbitrarty targets, and a faster inference due to the ``fast attention``. + +{{SECTION_INSTALLATION}} + +{{SECTION_DEFAULT_HYPERS}} + +Tuning hyperparameters +---------------------- + +The default hyperparameters above will work well in most cases, but they +may not be optimal for your specific dataset. There is good number of +parameters to tune, both for the +:ref:`model ` and the +:ref:`trainer `. Since seeing them +for the first time might be overwhelming, here we provide a **list of the +parameters that are in general the most important** (in decreasing order +of importance): + +.. container:: mtt-hypers-remove-classname + + .. 
+
+    .. autoattribute:: {{model_hypers_path}}.num_neighbors_adaptive
+        :no-index:
+
+    .. autoattribute:: {{trainer_hypers_path}}.learning_rate
+        :no-index:
+
+    .. autoattribute:: {{trainer_hypers_path}}.batch_size
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.d_pet
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.d_node
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.num_gnn_layers
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.num_attention_layers
+        :no-index:
+
+    .. autoattribute:: {{trainer_hypers_path}}.loss
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.long_range
+        :no-index:
+"""
+
+from typing import Literal, Optional
+
+from typing_extensions import TypedDict
+
+from metatrain.pet_11.modules.finetuning import FinetuneHypers, NoFinetuneHypers
+from metatrain.utils.additive import FixedCompositionWeights
+from metatrain.utils.hypers import init_with_defaults
+from metatrain.utils.long_range import LongRangeHypers
+from metatrain.utils.loss import LossSpecification
+from metatrain.utils.scaler import FixedScalerWeights
+
+
+class ModelHypers(TypedDict):
+    """Hyperparameters for the PET model."""
+
+    cutoff: float = 4.5
+    """Cutoff radius for neighbor search.
+
+    This should be set to a value beyond which most of the interactions
+    between atoms are expected to be negligible. A lower cutoff will lead
+    to faster models.
+    """
+    num_neighbors_adaptive: Optional[int] = None
+    """Target number of neighbors for the adaptive cutoff scheme.
+
+    This parameter activates the adaptive cutoff functionality.
+    Each atomic environment has a different cutoff, chosen
+    such that the number of neighbors is approximately equal to this
+    value. This can be useful to have a more uniform number of neighbors
+    per atom, especially in sparse systems. Setting it to None disables
+    this feature and uses all neighbors within the fixed cutoff radius.
+    """
+    cutoff_function: Literal["Cosine", "Bump"] = "Bump"
+    """Type of the smoothing function at the cutoff"""
+    cutoff_width: float = 0.5
+    """Width of the smoothing function at the cutoff"""
+    d_pet: int = 128
+    """Dimension of the edge features.
+
+    This hyperparameter controls the width of the neural network. In general,
+    increasing it might lead to better accuracy, especially on larger datasets, at the
+    cost of increased training and evaluation time.
+    """
+    d_head: int = 128
+    """Dimension of the attention heads."""
+    d_node: int = 256
+    """Dimension of the node features.
+
+    Increasing this hyperparameter might lead to better accuracy,
+    with a relatively small increase in inference time.
+    """
+    d_feedforward: int = 256
+    """Dimension of the feedforward network in the attention layer."""
+    num_heads: int = 8
+    """Attention heads per attention layer."""
+    num_attention_layers: int = 2
+    """The number of attention layers in each layer of the graph
+    neural network. Depending on the dataset, increasing this hyperparameter might
+    lead to better accuracy, at the cost of increased training and evaluation time.
+    """
+    num_gnn_layers: int = 2
+    """The number of graph neural network layers.
+
+    In general, decreasing this hyperparameter to 1 will lead to much faster models,
+    at the expense of accuracy. Increasing it may or may not lead to better accuracy,
+    depending on the dataset, at the cost of increased training and evaluation time.
+ """ + normalization: Literal["RMSNorm", "LayerNorm"] = "RMSNorm" + """Layer normalization type.""" + activation: Literal["SiLU", "SwiGLU"] = "SwiGLU" + """Activation function.""" + attention_temperature: float = 1.0 + """The temperature scaling factor for attention scores.""" + transformer_type: Literal["PreLN", "PostLN"] = "PreLN" + """The order in which the layer normalization and attention + are applied in a transformer block. Available options are ``PreLN`` + (normalization before attention) and ``PostLN`` (normalization after attention).""" + featurizer_type: Literal["residual", "feedforward"] = "feedforward" + """Implementation of the featurizer of the model to use. Available + options are ``residual`` (the original featurizer from the PET paper, that uses + residual connections at each GNN layer for readout) and ``feedforward`` (a modern + version that uses the last representation after all GNN iterations for readout). + Additionally, the feedforward version uses bidirectional features flow during the + message passing iterations, that favors features flowing from atom ``i`` to atom + ``j`` to be not equal to the features flowing from atom ``j`` to atom ``i``.""" + zbl: bool = False + """Use ZBL potential for short-range repulsion""" + long_range: LongRangeHypers = init_with_defaults(LongRangeHypers) + """Long-range Coulomb interactions parameters.""" + + +class TrainerHypers(TypedDict): + """Hyperparameters for training PET models.""" + + distributed: bool = False + """Whether to use distributed training""" + distributed_port: int = 39591 + """Port for distributed communication among processes""" + batch_size: int = 16 + """The number of samples to use in each batch of training. This + hyperparameter controls the tradeoff between training speed and memory usage. In + general, larger batch sizes will lead to faster training, but might require more + memory.""" + num_epochs: int = 1000 + """Number of epochs.""" + warmup_fraction: float = 0.01 + """Fraction of training steps used for learning rate warmup.""" + learning_rate: float = 1e-4 + """Learning rate.""" + weight_decay: Optional[float] = None + + log_interval: int = 1 + """Interval to log metrics.""" + checkpoint_interval: int = 100 + """Interval to save checkpoints.""" + atomic_baseline: FixedCompositionWeights = {} + """The baselines for each target. + + By default, ``metatrain`` will fit a linear model (:class:`CompositionModel + `) to compute the + least squares baseline for each atomic species for each target. + + However, this hyperparameter allows you to provide your own baselines. + The value of the hyperparameter should be a dictionary where the keys are the + target names, and the values are either (1) a single baseline to be used for + all atomic types, or (2) a dictionary mapping atomic types to their baselines. + For example: + + - ``atomic_baseline: {"energy": {1: -0.5, 6: -10.0}}`` will fix the energy + baseline for hydrogen (Z=1) to -0.5 and for carbon (Z=6) to -10.0, while + fitting the baselines for the energy of all other atomic types, as well + as fitting the baselines for all other targets. + - ``atomic_baseline: {"energy": -5.0}`` will fix the energy baseline for + all atomic types to -5.0. + - ``atomic_baseline: {"mtt:dos": 0.0}`` sets the baseline for the "mtt:dos" + target to 0.0, effectively disabling the atomic baseline for that target. 
+ + This atomic baseline is substracted from the targets during training, which + avoids the main model needing to learn atomic contributions, and likely makes + training easier. When the model is used in evaluation mode, the atomic baseline + is added on top of the model predictions automatically. + + .. note:: + + This atomic baseline is a per-atom contribution. Therefore, if the property + you are predicting is a sum over all atoms (e.g., total energy), the + contribution of the atomic baseline to the total property will be the + atomic baseline multiplied by the number of atoms of that type in the + structure. + """ + scale_targets: bool = True + """Normalize targets to unit std during training.""" + fixed_scaling_weights: FixedScalerWeights = {} + """Weights for target scaling. + + This is passed to the ``fixed_weights`` argument of + :meth:`Scaler.train_model `, + see its documentation to understand exactly what to pass here. + """ + per_structure_targets: list[str] = [] + """Targets to calculate per-structure losses.""" + num_workers: Optional[int] = None + """Number of workers for data loading. If not provided, it is set + automatically.""" + log_mae: bool = True + """Log MAE alongside RMSE""" + log_separate_blocks: bool = False + """Log per-block error.""" + best_model_metric: Literal["rmse_prod", "mae_prod", "loss"] = "mae_prod" + """Metric used to select best checkpoint (e.g., ``rmse_prod``)""" + grad_clip_norm: float = 1.0 + """Maximum gradient norm value.""" + loss: str | dict[str, LossSpecification | str] = "mse" + """This section describes the loss function to be used. See the + :ref:`loss-functions` for more details.""" + batch_atom_bounds: list[Optional[int]] = [None, None] + """Bounds for the number of atoms per batch as [min, max]. Batches with atom + counts outside these bounds will be skipped during training. Use ``None`` for + either value to disable that bound. This is useful for preventing out-of-memory + errors and ensuring consistent computational load. Default: ``[None, None]``.""" + + finetune: NoFinetuneHypers | FinetuneHypers = { + "read_from": None, + "method": "full", + "config": {}, + "inherit_heads": {}, + } + """Parameters for fine-tuning trained PET models. + + See :ref:`label_fine_tuning_concept` for more details. + """ diff --git a/src/metatrain/pet_11/model.py b/src/metatrain/pet_11/model.py new file mode 100644 index 0000000000..bc4e151407 --- /dev/null +++ b/src/metatrain/pet_11/model.py @@ -0,0 +1,1646 @@ +import logging +import typing +import warnings +from math import prod +from typing import Any, Dict, List, Literal, Optional, Tuple + +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.operations._add import _add_block_block +from metatomic.torch import ( + AtomisticModel, + ModelCapabilities, + ModelMetadata, + ModelOutput, + NeighborListOptions, + System, +) +from torch.utils.hooks import RemovableHandle + +from metatrain.utils.abc import ModelInterface +from metatrain.utils.additive import ZBL, CompositionModel +from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.dtype import dtype_to_str +from metatrain.utils.long_range import DummyLongRangeFeaturizer, LongRangeFeaturizer +from metatrain.utils.metadata import merge_metadata +from metatrain.utils.scaler import Scaler +from metatrain.utils.sum_over_atoms import sum_over_atoms + +from . 
import checkpoints +from .documentation import ModelHypers +from .modules.finetuning import apply_finetuning_strategy +from .modules.structures import get_pair_sample_labels, systems_to_batch +from .modules.transformer import CartesianTransformer + + +AVAILABLE_FEATURIZERS = typing.get_args(ModelHypers.__annotations__["featurizer_type"]) + + +class PET(ModelInterface[ModelHypers]): + """ + Metatrain-native implementation of the PET architecture. + + Originally proposed in work (https://arxiv.org/abs/2305.19302v3), + and published in the `pet` package (https://github.com/spozdn/pet). + + :param hypers: Hyperparameters for the PET model. See the documentation for details. + :param dataset_info: Information about the dataset, including atomic types and + targets. + """ + + __checkpoint_version__ = 11 + __supported_devices__ = ["cuda", "cpu"] + __supported_dtypes__ = [torch.float32, torch.float64] + __default_metadata__ = ModelMetadata( + references={"architecture": ["https://arxiv.org/abs/2305.19302v3"]} + ) + component_labels: Dict[str, List[List[Labels]]] + NUM_FEATURE_TYPES: int = 2 # node + edge features + + def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None: + super().__init__(hypers, dataset_info, self.__default_metadata__) + + # Cache frequently accessed hyperparameters + self.cutoff = float(self.hypers["cutoff"]) + self.cutoff_function = self.hypers["cutoff_function"] + self.cutoff_width = float(self.hypers["cutoff_width"]) + self.num_neighbors_adaptive = ( + float(self.hypers["num_neighbors_adaptive"]) + if self.hypers["num_neighbors_adaptive"] is not None + else None + ) + self.d_pet = self.hypers["d_pet"] + self.d_node = self.hypers["d_node"] + self.d_head = self.hypers["d_head"] + self.d_feedforward = self.hypers["d_feedforward"] + self.num_heads = self.hypers["num_heads"] + self.num_gnn_layers = self.hypers["num_gnn_layers"] + self.num_attention_layers = self.hypers["num_attention_layers"] + self.normalization = self.hypers["normalization"] + self.activation = self.hypers["activation"] + self.attention_temperature = self.hypers["attention_temperature"] + self.transformer_type = self.hypers["transformer_type"] + self.featurizer_type = self.hypers["featurizer_type"] + + self.atomic_types = dataset_info.atomic_types + self.requested_nl = NeighborListOptions( + cutoff=self.cutoff, + full_list=True, + strict=True, + ) + num_atomic_species = len(self.atomic_types) + self.gnn_layers = torch.nn.ModuleList( + [ + CartesianTransformer( + self.cutoff, + self.cutoff_width, + self.d_pet, + self.num_heads, + self.d_node, + self.d_feedforward, + self.num_attention_layers, + self.normalization, + self.activation, + self.attention_temperature, + self.transformer_type, + num_atomic_species, + layer_index == 0, # is first layer + ) + for layer_index in range(self.num_gnn_layers) + ] + ) + if self.featurizer_type not in AVAILABLE_FEATURIZERS: + raise ValueError( + f"Unknown featurizer type: {self.featurizer_type}. 
" + f"Available options are: {AVAILABLE_FEATURIZERS}" + ) + if self.featurizer_type == "feedforward": + self.num_readout_layers = 1 + self.combination_norms = torch.nn.ModuleList( + [torch.nn.LayerNorm(2 * self.d_pet) for _ in range(self.num_gnn_layers)] + ) + self.combination_mlps = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(2 * self.d_pet, 2 * self.d_pet), + torch.nn.SiLU(), + torch.nn.Linear(2 * self.d_pet, self.d_pet), + ) + for _ in range(self.num_gnn_layers) + ] + ) + else: + self.num_readout_layers = self.num_gnn_layers + self.combination_norms = torch.nn.ModuleList() + self.combination_mlps = torch.nn.ModuleList() + + self.node_embedders = torch.nn.ModuleList( + [ + torch.nn.Embedding(num_atomic_species, self.d_node) + for _ in range(self.num_readout_layers) + ] + ) + self.edge_embedder = torch.nn.Embedding(num_atomic_species, self.d_pet) + + self.node_heads = torch.nn.ModuleDict() + self.edge_heads = torch.nn.ModuleDict() + self.node_last_layers = torch.nn.ModuleDict() + self.edge_last_layers = torch.nn.ModuleDict() + self.last_layer_feature_size = ( + self.num_readout_layers * self.d_head * self.NUM_FEATURE_TYPES + ) # for LLPR + + # the model is always capable of outputting the internal features + self.outputs = { + "features": ModelOutput(per_atom=True, description="internal features") + } + + self.output_shapes: Dict[str, Dict[str, List[int]]] = {} + self.key_labels: Dict[str, Labels] = {} + self.property_labels: Dict[str, List[Labels]] = {} + self.component_labels: Dict[str, List[List[Labels]]] = {} + self.target_names: List[str] = [] + self.last_layer_parameter_names: Dict[str, List[str]] = {} # for LLPR + for target_name, target_info in dataset_info.targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target_info) + + self.register_buffer( + "species_to_species_index", + torch.full((max(self.atomic_types) + 1,), -1), + ) + for i, species in enumerate(self.atomic_types): + self.species_to_species_index[species] = i + + # long-range module + if self.hypers["long_range"]["enable"]: + self.long_range = True + if not self.hypers["long_range"]["use_ewald"]: + warnings.warn( + "Training PET with the LongRangeFeaturizer initialized " + "with `use_ewald=False` causes instabilities during training. " + "The `use_ewald` variable will be force-switched to `True`. 
" + "during training.", + UserWarning, + stacklevel=2, + ) + self.long_range_featurizer = LongRangeFeaturizer( + hypers=self.hypers["long_range"], + feature_dim=self.d_node, + neighbor_list_options=self.requested_nl, + ) + else: + self.long_range = False + self.long_range_featurizer = DummyLongRangeFeaturizer() # for torchscript + + # additive models: these are handled by the trainer at training + # time, and they are added to the output at evaluation time + composition_model = CompositionModel( + hypers={}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + additive_models = [composition_model] + + # Adds the ZBL repulsion model if requested + if self.hypers["zbl"]: + additive_models.append( + ZBL( + {}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if ZBL.is_valid_target(target_name, target_info) + }, + ), + ) + ) + self.additive_models = torch.nn.ModuleList(additive_models) + + # scaler: this is also handled by the trainer at training time + self.scaler = Scaler(hypers={}, dataset_info=dataset_info) + + self.single_label = Labels.single() + + self.finetune_config: Dict[str, Any] = {} + + def supported_outputs(self) -> Dict[str, ModelOutput]: + return self.outputs + + def restart(self, dataset_info: DatasetInfo) -> "PET": + # merge old and new dataset info + merged_info = self.dataset_info.union(dataset_info) + new_atomic_types = [ + at for at in merged_info.atomic_types if at not in self.atomic_types + ] + new_targets = { + key: value + for key, value in merged_info.targets.items() + if key not in self.dataset_info.targets + } + self.has_new_targets = len(new_targets) > 0 + + if len(new_atomic_types) > 0: + raise ValueError( + f"New atomic types found in the dataset: {new_atomic_types}. " + "The PET model does not support adding new atomic types." + ) + + # register new outputs as new last layers + for target_name, target in new_targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target) + + self.dataset_info = merged_info + + # restart the composition and scaler models + self.additive_models[0] = self.additive_models[0].restart( + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + self.scaler = self.scaler.restart(dataset_info) + + return self + + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return [self.requested_nl] + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + """ + Forward pass of the PET model. + + The forward pass processes atomic systems through multiple stages to produce + predictions for the requested outputs. The computation follows a graph neural + network architecture with attention-based message passing. 
+ + **Stage 0: Input Preparation** + + The input systems are first converted into a batched representation containing: + + - `element_indices_nodes` [n_atoms]: Atomic species of the central atoms + - `element_indices_neighbors` [n_atoms, max_num_neighbors]: Atomic species of + neighboring atoms + - `edge_vectors` [n_atoms, max_num_neighbors, 3]: Cartesian edge vectors + between central atoms and their neighbors + - `padding_mask` [n_atoms, max_num_neighbors]: Mask indicating real vs padded + neighbors + - `reverse_neighbor_index` [n_atoms * max_num_neighbors]: Index of the ji edge + for each ij edge, once the edges are flattened into an array whose first + dimension is n_atoms * max_num_neighbors + - `system_indices` [n_atoms]: System index for each central atom + - `sample_labels` [n_atoms, 2]: Metatensor Labels containing indices of each + atom in each system + + **Stage 1: Feature Computation via GNN Layers** + + Node and edge representations are computed by iterating through the GNN layers + following one of two featurization strategies: + + - **Feedforward featurization**: Features are propagated through all + `num_gnn_layers` GNN layers sequentially, using only the final layer outputs + for readout. At each layer, forward and reversed edge messages are combined + using combination MLPs to enable bidirectional information flow. + + - **Residual featurization**: Intermediate node and edge features from each + GNN layer are saved and used during readout. Edge messages between layers are + averaged to maintain information from all hops. + + During this stage, the model: + + - Embeds atomic species into learned node and edge representations + - Applies Cartesian transformer layers to update features via attention + - Uses reversed neighbor lists to enable bidirectional message passing, where + the new input message from atom `j` to atom `i` in GNN layer N+1 is the + reversed message from atom `i` to atom `j` in GNN layer N + - Applies cutoff functions to weight interactions by distance + + If the long-range module is enabled, electrostatic features computed via Ewald + summation (during training) or Particle-Particle Particle Mesh Ewald (P3M) + (during evaluation) are added to the node features from each GNN layer. + + **Stage 2: Intermediate Feature Output (Optional)** + + If "features" is requested in the outputs, node and edge features from all + layers are concatenated to produce intermediate representations. Edge features + are summed over neighbors with cutoff weighting to obtain per-node + contributions. This output can be used for transfer learning or analysis. + + **Stage 3: Last Layer Feature Computation** + + For each requested output, output-specific heads (shallow MLPs with two linear + layers and SiLU activations) are applied to both node and edge features from + each GNN layer. This produces last layer features that are specialized for each + prediction target. These features can be optionally returned as auxiliary + outputs (e.g., "mtt::aux::energy_last_layer_features") for analysis or + transfer learning. 
+ + **Stage 4: Atomic Predictions** + + Final linear layers are applied to the last layer features to produce per-atom + predictions for each requested output: + + - Node and edge last layer features are processed through separate linear + layers for each output block + - Contributions from all GNN layers are summed + - Edge contributions are summed over neighbors with cutoff weighting + - For rank-2 Cartesian tensors (e.g., stress), predictions are symmetrized and + normalized by cell volume + - Multiple tensor blocks per output are handled independently + + **Post-processing (Evaluation Only)** + + During evaluation (not training), the following transformations are applied: + + 1. **Scaling**: Predictions are scaled using learned or configured scale + factors + 2. **Additive contributions**: Composition model and optional ZBL repulsion + contributions are added to the predictions + + :param systems: List of `metatomic.torch.System` objects to process. Each + system should contain atomic positions, species, and cell information, with + neighbor lists computed according to `requested_neighbor_lists()`. + :param outputs: Dictionary of requested outputs in the format + {output_name: ModelOutput(...)}. The model supports: + + - Target properties (energy, forces, stress, etc.) + - "features": intermediate representations from Stage 2 + - Auxiliary last layer features (e.g., + "mtt::aux::energy_last_layer_features") + + :param selected_atoms: Optional `metatensor.torch.Labels` object specifying a + subset of atoms for which to compute outputs. If `None`, all atoms are + included. This is useful for computing properties for specific atomic + environments. + :return: Dictionary of `metatensor.torch.TensorMap` objects containing the + requested outputs. Each TensorMap contains per-atom or per-structure + predictions (depending on the ModelOutput configuration) with appropriate + metatensor metadata (samples, components, properties). 
+ """ + device = systems[0].device + return_dict: Dict[str, TensorMap] = {} + nl_options = self.requested_neighbor_lists()[0] + + if self.single_label.values.device != device: + self._move_labels_to_device(device) + + # **Stage 0: Input Preparation** + ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + padding_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + sample_labels, + centers, + nef_to_edges_neighbor, + ) = systems_to_batch( + systems, + nl_options, + self.atomic_types, + self.species_to_species_index, + self.cutoff_function, + self.cutoff_width, + self.num_neighbors_adaptive, + selected_atoms, + ) + + pair_sample_labels = get_pair_sample_labels( + systems, sample_labels, nl_options, device + ) + + # Optional diagnostic token capture: register temporary module hooks + diagnostic_handles = torch.jit.annotate(List[Any], []) + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + diagnostic_handles = self._prepare_diagnostic_handles( + outputs, + return_dict, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + + # the scaled_dot_product_attention function from torch cannot do + # double backward, so we will use manual attention if needed + use_manual_attention = edge_vectors.requires_grad and self.training + + with torch.profiler.record_function("PET::_calculate_features"): + # **Stage 1: Feature Computation via GNN Layers** + featurizer_inputs: Dict[str, torch.Tensor] = dict( + element_indices_nodes=element_indices_nodes, + element_indices_neighbors=element_indices_neighbors, + edge_vectors=edge_vectors, + edge_distances=edge_distances, + reverse_neighbor_index=reverse_neighbor_index, + padding_mask=padding_mask, + cutoff_factors=cutoff_factors, + ) + for featurizer_input_name, tensor in featurizer_inputs.items(): + if "mtt::features::" + featurizer_input_name in outputs: + return_dict["mtt::features::" + featurizer_input_name] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + node_features_list, edge_features_list = self._calculate_features( + featurizer_inputs, + use_manual_attention=use_manual_attention, + ) + + # If the long-range module is activated, we add the long-range features + # on top of the node features + + if self.long_range: + long_range_features = self._calculate_long_range_features( + systems, node_features_list, edge_distances, padding_mask + ) + for i in range(self.num_readout_layers): + node_features_list[i] = ( + node_features_list[i] + long_range_features + ) * 0.5**0.5 + + # **Stage 2: Intermediate Feature Output (Optional)** + with torch.profiler.record_function("PET::_get_output_features"): + if "features" in outputs: + features_dict = self._get_output_features( + node_features_list, + edge_features_list, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + # Since return_dict.update(features_dict) is not Torch-Scriptable, + # we use a simple iteration over the features_dict items. 
+ for k, v in features_dict.items(): + return_dict[k] = v + + # **Stage 3: Last Layer Feature Computation** + with torch.profiler.record_function("PET::_calculate_last_layer_features"): + node_last_layer_features_dict, edge_last_layer_features_dict = ( + self._calculate_last_layer_features( + node_features_list, + edge_features_list, + ) + ) + last_layer_features_dict = self._get_output_last_layer_features( + node_last_layer_features_dict, + edge_last_layer_features_dict, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + + for k, v in last_layer_features_dict.items(): + return_dict[k] = v + + # **Stage 4: Atomic Predictions** + with torch.profiler.record_function("PET::_calculate_atomic_predictions"): + node_atomic_predictions_dict, edge_atomic_predictions_dict = ( + self._calculate_atomic_predictions( + node_last_layer_features_dict, + edge_last_layer_features_dict, + padding_mask, + cutoff_factors, + outputs, + ) + ) + atomic_predictions_dict = self._get_output_atomic_predictions( + systems, + node_atomic_predictions_dict, + edge_atomic_predictions_dict, + edge_vectors, + system_indices, + sample_labels, + outputs, + selected_atoms, + ) + + for k, v in atomic_predictions_dict.items(): + return_dict[k] = v + + # **Post-processing (Evaluation Only)** + + with torch.profiler.record_function("PET::post-processing"): + if not self.training: + # at evaluation, we also introduce the scaler and additive contributions + return_dict = self.scaler( + systems, return_dict, selected_atoms=selected_atoms + ) + for additive_model in self.additive_models: + outputs_for_additive_model: Dict[str, ModelOutput] = {} + for name, output in outputs.items(): + if name in additive_model.outputs: + outputs_for_additive_model[name] = output + additive_contributions = additive_model( + systems, + outputs_for_additive_model, + selected_atoms, + ) + for name in additive_contributions: + # TODO: uncomment this after metatensor.torch.add + # is updated to handle sparse sums + # return_dict[name] = metatensor.torch.add( + # return_dict[name], + # additive_contributions[name].to( + # device=return_dict[name].device, + # dtype=return_dict[name].dtype + # ), + # ) + # TODO: "manual" sparse sum: update to metatensor.torch.add + # after sparse sum is implemented in metatensor.operations + output_blocks: List[TensorBlock] = [] + for k, b in return_dict[name].items(): + if k in additive_contributions[name].keys: + output_blocks.append( + _add_block_block( + b, + additive_contributions[name] + .block(k) + .to(device=b.device, dtype=b.dtype), + ) + ) + else: + output_blocks.append(b) + return_dict[name] = TensorMap( + return_dict[name].keys, output_blocks + ) + + # remove any diagnostic hooks we registered and attach tokens + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + for h in diagnostic_handles: + h.remove() + + return return_dict + + def _create_diagnostic_feature_tensormap( + self, + tensor: torch.Tensor, + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> TensorMap: + assert tensor.shape[0] == sample_labels.values.shape[0], ( + "diagnostic feature tensor must be per-atom or per-pair like in shape." 
+ f" Got tensor.shape = {tensor.shape}, " + ) + + device = tensor.device + + outp = tensor.detach().clone() + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + if outp.shape[1] == 1: # node-like, shape (n_atoms, 1, d) + outp = outp.squeeze(1) + labels = sample_labels + + else: # edge-like, shape (n_atoms, num_neighbors, d) + outp = outp[centers, nef_to_edges_neighbor] + labels = pair_sample_labels + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + return TensorMap( + Labels(["_"], torch.tensor([[0]])).to(device=device), + [ + TensorBlock( + values=outp, + samples=labels.to(device=device), + components=[], + properties=Labels( + ["_"], + torch.arange(outp.shape[1]).reshape(-1, 1), + ).to(device=device), + ) + ], + ) + + @torch.jit.ignore + def _prepare_diagnostic_handles( + self, + outputs: Dict[str, ModelOutput], + return_dict: Dict[str, Any], + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> List[Any]: + """ + Prepare forward hooks to capture diagnostic tokens from internal modules. + + :param outputs: Dictionary of requested outputs. + :param return_dict: Dictionary to store captured tokens. + :param centers: Tensor mapping center atoms to their indices. + :param nef_to_edges_neighbor: Tensor mapping neighbor edges to their indices. + :param sample_labels: Labels for individual atoms. + :param pair_sample_labels: Labels for atom pairs. + + :return: List of removable handles for the registered hooks. + """ + diagnostic_handles = [] + + def _resolve_module(path: str) -> Any: + obj: Any = self + for part in path.split("."): + if part.isdigit(): + obj = obj[int(part)] + else: + if not hasattr(obj, part): + raise AttributeError( + f"Module path '{path}' not found at '{part}'" + ) + obj = getattr(obj, part) + return obj + + # Build list of possible module paths that can be captured for these model + # hypers. 
+ possible_capture_paths: List[str] = [] + + for i in range(self.num_readout_layers): + possible_capture_paths.append(f"node_embedders.{i}") + possible_capture_paths.append("edge_embedder") + + for i in range(self.num_gnn_layers): + # Total GNN layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}_node") + possible_capture_paths.append(f"gnn_layers.{i}_edge") + + # Finer-grained: embeddings and compressions + possible_capture_paths.append(f"gnn_layers.{i}.edge_embedder") + if i > 0: + possible_capture_paths.append(f"gnn_layers.{i}.neighbor_embedder") + possible_capture_paths.append(f"gnn_layers.{i}.compress") + + # transformer layers + for j in range(self.num_attention_layers): + # Transformer layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_node") + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_edge") + + # Finer-grained: attention and MLP submodules + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_contraction" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_mlp" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_mlp" + ) + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}.mlp") + + # TODO: this is subtle, depending on whether the featurizer is feedforward or + # residual. To fix later. + # possible_capture_paths.append(f"combination_norms") + # possible_capture_paths.append(f"combination_mlps") + + for path in possible_capture_paths: + if "mtt::features::" + path not in outputs: + continue + + if "_node" in path: + suffix = "_node" + path = path.replace(suffix, "") + elif "_edge" in path: + suffix = "_edge" + path = path.replace(suffix, "") + else: + suffix = "" + + module = _resolve_module(path) + + def make_hook(p: str, suffix: str) -> Any: + def _hook( + module: torch.nn.Module, inp: Any, outp: Any + ) -> None: + if isinstance(outp, tuple): + assert "_node" in suffix or "_edge" in suffix, ( + "When capturing from a module that returns multiple " + "outputs, the requested output must carry the suffix " + "'_node' or '_edge'." + ) + if suffix == "_node": + tensor = outp[0] + else: + assert suffix == "_edge" + tensor = outp[1] + + return_dict[f"mtt::features::{p}{suffix}"] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + else: + return_dict["mtt::features::" + p] = ( + self._create_diagnostic_feature_tensormap( + outp, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + + return _hook + + handle = module.register_forward_hook(make_hook(path, suffix)) + diagnostic_handles.append(handle) + + return diagnostic_handles + + def _calculate_features( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Calculate node and edge features using the selected featurization strategy. + Returns lists of feature tensors from GNN layers. 
+ + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors + - List of edge feature tensors + In the case of feedforward featurization, each list contains a single tensor + from the final GNN layer. In the case of residual featurization, each list + contains tensors from all GNN layers. + """ + if self.featurizer_type == "feedforward": + return self._feedforward_featurization_impl(inputs, use_manual_attention) + else: + return self._residual_featurization_impl(inputs, use_manual_attention) + + def _feedforward_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Feedforward featurization: iterates features through all GNN layers, + returning only the final layer outputs. Uses combination MLPs to mix + forward and reversed edge messages at each layer. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from the final GNN layer + - List of edge feature tensors from the final GNN layer + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + + input_node_embeddings = self.node_embedders[0](inputs["element_indices_nodes"]) + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for combination_norm, combination_mlp, gnn_layer in zip( + self.combination_norms, self.combination_mlps, self.gnn_layers, strict=True + ): + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + input_node_embeddings = output_node_embeddings + new_input_edge_embeddings = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + # input_messages = 0.5 * (output_edge_embeddings + new_input_messages) + concatenated = torch.cat( + [output_edge_embeddings, new_input_edge_embeddings], dim=-1 + ) + input_edge_embeddings = ( + input_edge_embeddings + + output_edge_embeddings + + combination_mlp(combination_norm(concatenated)) + ) + + node_features_list.append(input_node_embeddings) + edge_features_list.append(input_edge_embeddings) + return node_features_list, edge_features_list + + def _residual_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Residual featurization: saves intermediate features from each GNN layer + for use in readout. 
Averages forward and reversed edge messages between layers. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from all GNN layers + - List of edge feature tensors from all GNN layers + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for node_embedder, gnn_layer in zip( + self.node_embedders, self.gnn_layers, strict=True + ): + input_node_embeddings = node_embedder(inputs["element_indices_nodes"]) + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + node_features_list.append(output_node_embeddings) + edge_features_list.append(output_edge_embeddings) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + # (Flatten, index, and reshape to the original shape) + new_input_messages = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + input_edge_embeddings = 0.5 * (input_edge_embeddings + new_input_messages) + return node_features_list, edge_features_list + + def _calculate_long_range_features( + self, + systems: List[System], + node_features_list: List[torch.Tensor], + edge_distances: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + """ + Calculate long-range electrostatic features using Ewald summation. + Forces use_ewald=True during training for stability. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_distances: Tensor of edge distances [n_atoms, max_num_neighbors]. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :return: Tensor of long-range features [n_atoms, d_pet]. + """ + if self.training: + # Currently, the long-range implementation show instabilities + # during training if P3MCalculator is used instead of the + # EwaldCalculator. We will use the EwaldCalculator for training. 
+ self.long_range_featurizer.use_ewald = True + flattened_lengths = edge_distances[padding_mask] + short_range_features = ( + torch.stack(node_features_list).sum(dim=0) + * (1 / len(node_features_list)) ** 0.5 + ) + long_range_features = self.long_range_featurizer( + systems, short_range_features, flattened_lengths + ) + return long_range_features + + def _get_output_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Concatenate node and edge features from all layers into intermediate + feature representations. Edge features are summed with cutoff weighting. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping "features" to a TensorMap of intermediate + representations, either per-atom or summed over atoms. + """ + features_dict: Dict[str, TensorMap] = {} + node_features = torch.cat(node_features_list, dim=1) + edge_features = torch.cat(edge_features_list, dim=2) + edge_features = (edge_features * cutoff_factors[:, :, None]).sum(dim=1) + features = torch.cat([node_features, edge_features], dim=1) + + feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=features, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + features.shape[-1], device=features.device + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + feature_tmap = mts.slice( + feature_tmap, + axis="samples", + selection=selected_atoms, + ) + if requested_outputs["features"].per_atom: + features_dict["features"] = feature_tmap + else: + features_dict["features"] = sum_over_atoms(feature_tmap) + return features_dict + + def _calculate_last_layer_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + ) -> Tuple[Dict[str, List[torch.Tensor]], Dict[str, List[torch.Tensor]]]: + """ + Apply output-specific heads to node and edge features from each GNN layer. + Returns dictionaries mapping output names to lists of head-transformed features. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of node last layer features + - Dictionary mapping output names to lists of edge last layer features + """ + node_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + + # Calculating node last layer features + for output_name, node_heads in self.node_heads.items(): + if output_name not in node_last_layer_features_dict: + node_last_layer_features_dict[output_name] = [] + for i, node_head in enumerate(node_heads): + node_last_layer_features_dict[output_name].append( + node_head(node_features_list[i]) + ) + + # Calculating edge last layer features + for output_name, edge_heads in self.edge_heads.items(): + if output_name not in edge_last_layer_features_dict: + edge_last_layer_features_dict[output_name] = [] + for i, edge_head in enumerate(edge_heads): + edge_last_layer_features_dict[output_name].append( + edge_head(edge_features_list[i]) + ) + + return node_last_layer_features_dict, edge_last_layer_features_dict + + def _get_output_last_layer_features( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge last layer features for requested last layer + features output. Edge features are summed with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping requested last layer features output names + to TensorMaps of last layer features, either per-atom or summed over atoms. 
+ """ + last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + last_layer_features_outputs: Dict[str, TensorMap] = {} + for output_name in node_last_layer_features_dict.keys(): + if not should_compute_last_layer_features(output_name, requested_outputs): + continue + if output_name not in last_layer_features_dict: + last_layer_features_dict[output_name] = [] + for i in range(len(node_last_layer_features_dict[output_name])): + node_last_layer_features = node_last_layer_features_dict[output_name][i] + edge_last_layer_features = edge_last_layer_features_dict[output_name][i] + edge_last_layer_features = ( + edge_last_layer_features * cutoff_factors[:, :, None] + ).sum(dim=1) + last_layer_features_dict[output_name].append(node_last_layer_features) + last_layer_features_dict[output_name].append(edge_last_layer_features) + + for output_name in requested_outputs: + if not ( + output_name.startswith("mtt::aux::") + and output_name.endswith("_last_layer_features") + ): + continue + base_name = output_name.replace("mtt::aux::", "").replace( + "_last_layer_features", "" + ) + # the corresponding output could be base_name or mtt::base_name + if f"mtt::{base_name}" in last_layer_features_dict: + base_name = f"mtt::{base_name}" + last_layer_features_values = torch.cat( + last_layer_features_dict[base_name], dim=1 + ) + last_layer_feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=last_layer_features_values, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + last_layer_features_values.shape[-1], + device=last_layer_features_values.device, + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + last_layer_feature_tmap = mts.slice( + last_layer_feature_tmap, + axis="samples", + selection=selected_atoms, + ) + last_layer_features_options = requested_outputs[output_name] + if last_layer_features_options.per_atom: + last_layer_features_outputs[output_name] = last_layer_feature_tmap + else: + last_layer_features_outputs[output_name] = sum_over_atoms( + last_layer_feature_tmap + ) + return last_layer_features_outputs + + def _calculate_atomic_predictions( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + padding_mask: torch.Tensor, + cutoff_factors: torch.Tensor, + outputs: Dict[str, ModelOutput], + ) -> Tuple[ + Dict[str, List[List[torch.Tensor]]], Dict[str, List[List[torch.Tensor]]] + ]: + """ + Apply final linear layers to last layer features to produce + per-atom predictions. Handles multiple blocks per output and sums + edge contributions with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param outputs: Dictionary of requested outputs. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of lists of node atomic + prediction tensors (one list per GNN layer, one tensor per block) + - Dictionary mapping output names to lists of lists of edge atomic + prediction tensors (one list per GNN layer, one tensor per block) + """ + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + + # Computing node atomic predictions. Since we have last layer features + # for each GNN layer, and each last layer can have multiple blocks, + # we apply each last layer block to each of the last layer features. + + for output_name, node_last_layers in self.node_last_layers.items(): + if output_name in outputs: + node_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, node_last_layer in enumerate(node_last_layers): + node_last_layer_features = node_last_layer_features_dict[ + output_name + ][i] + node_atomic_predictions_by_block: List[torch.Tensor] = [] + for node_last_layer_by_block in node_last_layer.values(): + node_atomic_predictions_by_block.append( + node_last_layer_by_block(node_last_layer_features) + ) + node_atomic_predictions_dict[output_name].append( + node_atomic_predictions_by_block + ) + + # Computing edge atomic predictions. Following the same logic as above, + # we (1) iterate over the last layer features and last layer blocks, and (2) + # sum the edge features with cutoff factors to get their per-node contribution. + + for output_name, edge_last_layers in self.edge_last_layers.items(): + if output_name in outputs: + edge_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, edge_last_layer in enumerate(edge_last_layers): + edge_last_layer_features = edge_last_layer_features_dict[ + output_name + ][i] + edge_atomic_predictions_by_block: List[torch.Tensor] = [] + for edge_last_layer_by_block in edge_last_layer.values(): + edge_atomic_predictions = edge_last_layer_by_block( + edge_last_layer_features + ) + expanded_padding_mask = padding_mask[..., None].repeat( + 1, 1, edge_atomic_predictions.shape[2] + ) + edge_atomic_predictions = torch.where( + ~expanded_padding_mask, 0.0, edge_atomic_predictions + ) + edge_atomic_predictions_by_block.append( + (edge_atomic_predictions * cutoff_factors[:, :, None]).sum( + dim=1 + ) + ) + edge_atomic_predictions_dict[output_name].append( + edge_atomic_predictions_by_block + ) + + return node_atomic_predictions_dict, edge_atomic_predictions_dict + + def _get_output_atomic_predictions( + self, + systems: List[System], + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_vectors: torch.Tensor, + system_indices: torch.Tensor, + sample_labels: Labels, + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge atomic predictions into final TensorMaps. + Handles rank-2 Cartesian tensors by symmetrizing them. + Returns per-atom or per-structure predictions based on output configuration. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of node atomic prediction tensors (one list per GNN layer, + one tensor per block). 
+ :param edge_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of edge atomic prediction tensors (one list per GNN layer, + one tensor per block). + :param edge_vectors: Tensor of edge vectors [n_atoms, max_num_neighbors, 3]. + :param system_indices: Tensor mapping each atom to its system index + [n_atoms]. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param outputs: Dictionary of requested outputs. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :return: Dictionary mapping requested output names to TensorMaps of + predictions, either per-atom or summed over atoms. + """ + atomic_predictions_tmap_dict: Dict[str, TensorMap] = {} + for output_name in self.target_names: + if output_name in outputs: + atomic_predictions_by_block = { + key: torch.zeros( + 1, dtype=edge_vectors.dtype, device=edge_vectors.device + ) + for key in self.output_shapes[output_name].keys() + } + + node_atomic_predictions_by_block = node_atomic_predictions_dict[ + output_name + ] + edge_atomic_predictions_by_block = edge_atomic_predictions_dict[ + output_name + ] + for i in range(len(node_atomic_predictions_by_block)): + node_atomic_prediction_block = node_atomic_predictions_by_block[i] + edge_atomic_prediction_block = edge_atomic_predictions_by_block[i] + for j, key in enumerate(atomic_predictions_by_block): + node_atomic_predictions = node_atomic_prediction_block[j] + edge_atomic_predictions = edge_atomic_prediction_block[j] + atomic_predictions_by_block[key] = atomic_predictions_by_block[ + key + ] + (node_atomic_predictions + edge_atomic_predictions) + + if output_name == "non_conservative_stress": # TODO: variants + block_key = list(atomic_predictions_by_block.keys())[0] + output_shapes_values = list( + self.output_shapes[output_name].values() + ) + num_properties = output_shapes_values[0][-1] + symmetrized = process_non_conservative_stress( + atomic_predictions_by_block[block_key], + systems, + system_indices, + num_properties, + ) + atomic_predictions_by_block[block_key] = symmetrized + + blocks = [ + TensorBlock( + values=atomic_predictions_by_block[key].reshape([-1] + shape), + samples=sample_labels, + components=components, + properties=properties, + ) + for key, shape, components, properties in zip( + self.output_shapes[output_name].keys(), + self.output_shapes[output_name].values(), + self.component_labels[output_name], + self.property_labels[output_name], + strict=True, + ) + ] + atomic_predictions_tmap_dict[output_name] = TensorMap( + keys=self.key_labels[output_name], + blocks=blocks, + ) + # If selected atoms request is provided, we slice the atomic predictions + # tensor maps to get the predictions for the selected atoms only. + + if selected_atoms is not None: + for output_name, tmap in atomic_predictions_tmap_dict.items(): + atomic_predictions_tmap_dict[output_name] = mts.slice( + tmap, axis="samples", selection=selected_atoms + ) + + # If per-atom predictions are requested, we return the atomic predictions + # tensor maps. Otherwise, we sum the atomic predictions over the atoms + # to get the final per-structure predictions for each requested output. 
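    # Editor's note (illustrative sketch, not part of the patch): the loop below
    # collapses per-atom rows into per-structure rows. On plain tensors, assuming
    # `per_atom_values` of shape [n_atoms, n_properties], the `system_indices`
    # built earlier in `forward`, and `n_systems = len(systems)`, the equivalent
    # reduction would be:
    #
    #     per_structure = torch.zeros(
    #         n_systems,
    #         per_atom_values.shape[-1],
    #         dtype=per_atom_values.dtype,
    #         device=per_atom_values.device,
    #     )
    #     per_structure.index_add_(0, system_indices, per_atom_values)
    #
    # The actual code path instead calls the `sum_over_atoms` helper on
    # TensorMaps, which additionally keeps the sample/component/property
    # metadata consistent.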
+ + for output_name, atomic_property in atomic_predictions_tmap_dict.items(): + if outputs[output_name].per_atom: + atomic_predictions_tmap_dict[output_name] = atomic_property + else: + atomic_predictions_tmap_dict[output_name] = sum_over_atoms( + atomic_property + ) + + return atomic_predictions_tmap_dict + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + context: Literal["restart", "finetune", "export"], + ) -> "PET": + if context == "restart": + logging.info(f"Using latest model from epoch {checkpoint['epoch']}") + model_state_dict = checkpoint["model_state_dict"] + elif context in {"finetune", "export"}: + logging.info(f"Using best model from epoch {checkpoint['best_epoch']}") + model_state_dict = checkpoint["best_model_state_dict"] + else: + raise ValueError("Unknown context tag for checkpoint loading!") + + # Create the model + model_data = checkpoint["model_data"] + model = cls( + hypers=model_data["model_hypers"], + dataset_info=model_data["dataset_info"], + ) + + finetune_config = model_state_dict.pop("finetune_config", {}) + if finetune_config: + # Apply the finetuning strategy + model = apply_finetuning_strategy(model, finetune_config) + state_dict_iter = iter(model_state_dict.values()) + next(state_dict_iter) # skip the species_to_species_index + dtype = next(state_dict_iter).dtype + model.to(dtype).load_state_dict(model_state_dict) + model.additive_models[0].sync_tensor_maps() + model.scaler.sync_tensor_maps() + + # Loading the metadata from the checkpoint + model.metadata = merge_metadata(model.metadata, checkpoint.get("metadata")) + + return model + + def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: + dtype = next(self.parameters()).dtype + if dtype not in self.__supported_dtypes__: + raise ValueError(f"unsupported dtype {dtype} for PET") + + # Make sure the model is all in the same dtype + # For example, after training, the additive models could still be in + # float64 + self.to(dtype) + + # Additionally, the composition model contains some `TensorMap`s that cannot + # be registered correctly with Pytorch. This function moves them: + self.additive_models[0].weights_to(torch.device("cpu"), torch.float64) + + interaction_ranges = [self.num_gnn_layers * self.cutoff] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + + capabilities = ModelCapabilities( + outputs=self.outputs, + atomic_types=self.atomic_types, + interaction_range=interaction_range, + length_unit=self.dataset_info.length_unit, + supported_devices=self.__supported_devices__, + dtype=dtype_to_str(dtype), + ) + + metadata = merge_metadata(self.metadata, metadata) + + return AtomisticModel(self.eval(), metadata, capabilities) + + def _add_output(self, target_name: str, target_info: TargetInfo) -> None: + """ + Register a new output target by creating corresponding heads and last layers. + Sets up node/edge heads and linear layers for all readout layers. + + :param target_name: Name of the target to add. + :param target_info: TargetInfo object containing details about the target. + """ + # one output shape for each tensor block, grouped by target (i.e. 
tensormap) + self.output_shapes[target_name] = {} + for key, block in target_info.layout.items(): + dict_key = target_name + for n, k in zip(key.names, key.values, strict=True): + dict_key += f"_{n}_{int(k)}" + self.output_shapes[target_name][dict_key] = [ + len(comp.values) for comp in block.components + ] + [len(block.properties.values)] + + self.outputs[target_name] = ModelOutput( + quantity=target_info.quantity, + unit=target_info.unit, + per_atom=True, + description=target_info.description, + ) + + self.node_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_node, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_pet, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.node_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: torch.nn.Linear( + self.d_head, + prod(shape), + bias=True, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: torch.nn.Linear( + self.d_head, + prod(shape), + bias=True, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + # Register last-layer parameters, in the same order as they are returned as + # last-layer features in the model + self.last_layer_parameter_names[target_name] = [] + for layer_index in range(self.num_readout_layers): + for key in self.output_shapes[target_name].keys(): + self.last_layer_parameter_names[target_name].append( + f"node_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + self.last_layer_parameter_names[target_name].append( + f"edge_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + + ll_features_name = get_last_layer_features_name(target_name) + self.outputs[ll_features_name] = ModelOutput( + per_atom=True, description=f"last layer features for {target_name}" + ) + self.key_labels[target_name] = target_info.layout.keys + self.component_labels[target_name] = [ + block.components for block in target_info.layout.blocks() + ] + self.property_labels[target_name] = [ + block.properties for block in target_info.layout.blocks() + ] + + def _move_labels_to_device(self, device: torch.device) -> None: + self.single_label = self.single_label.to(device) + self.key_labels = { + output_name: label.to(device) + for output_name, label in self.key_labels.items() + } + self.component_labels = { + output_name: [ + [labels.to(device) for labels in components_block] + for components_block in components_tmap + ] + for output_name, components_tmap in self.component_labels.items() + } + self.property_labels = { + output_name: [labels.to(device) for labels in properties_tmap] + for output_name, properties_tmap in self.property_labels.items() + } + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["model_ckpt_version"] == v: + update = getattr(checkpoints, f"model_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["model_ckpt_version"] = v + 1 + + if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__: + raise 
RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using model " + f"version {checkpoint['model_ckpt_version']}, while the current model " + f"version is {cls.__checkpoint_version__}." + ) + + return checkpoint + + def get_checkpoint(self) -> Dict: + model_state_dict = self.state_dict() + model_state_dict["finetune_config"] = self.finetune_config + checkpoint = { + "architecture_name": "pet_11", + "model_ckpt_version": self.__checkpoint_version__, + "metadata": self.metadata, + "model_data": { + "model_hypers": self.hypers, + "dataset_info": self.dataset_info, + }, + "epoch": None, + "best_epoch": None, + "model_state_dict": model_state_dict, + "best_model_state_dict": self.state_dict(), + } + return checkpoint + + +def process_non_conservative_stress( + tensor: torch.Tensor, + systems: List[System], + system_indices: torch.Tensor, + num_properties: int, +) -> torch.Tensor: + """ + Symmetrizes and normalizes by the volume rank-2 Cartesian tensors that are meant + to predict the non-conservative stress. + + :param tensor: Tensor of shape [n_atoms, 9 * num_properties]. + :param systems: List of `metatomic.torch.System` objects to process. + :param system_indices: Tensor mapping each atom to its system index [n_atoms]. + :param num_properties: Number of properties in the tensor (e.g., 6 for stress). + :return: Symmetrized tensor of shape [n_atoms, 3, 3, num_properties], divided by the + cell volume. + """ + # Reshape to 3x3 matrix per atom + tensor_as_three_by_three = tensor.reshape(-1, 3, 3, num_properties) + + # Normalize by cell volume + volumes = torch.stack([torch.abs(torch.det(system.cell)) for system in systems]) + # Zero volume can happen due to metatomic's convention of zero cell + # vectors for non-periodic directions. The actual volume is +inf + volumes[volumes == 0.0] = torch.inf + volumes_by_atom = volumes[system_indices].unsqueeze(1).unsqueeze(2).unsqueeze(3) + tensor_as_three_by_three = tensor_as_three_by_three / volumes_by_atom + + # Symmetrize + tensor_as_three_by_three = ( + tensor_as_three_by_three + tensor_as_three_by_three.transpose(1, 2) + ) / 2.0 + + return tensor_as_three_by_three + + +def get_last_layer_features_name(target_name: str) -> str: + """ + Get the auxiliary output name for last layer features of a target. + + :param target_name: Name of the target. + :return: Name of the corresponding last layer features output. + """ + base_name = target_name.replace("mtt::", "") + return f"mtt::aux::{base_name}_last_layer_features" + + +def should_compute_last_layer_features( + output_name: str, requested_outputs: Dict[str, ModelOutput] +) -> bool: + """ + Check if last layer features should be computed for an output. + + :param output_name: Name of the output to check. + :param requested_outputs: Dictionary of requested outputs. + :return: True if last layer features should be computed, False otherwise. + """ + if output_name in requested_outputs: + return True + ll_features_name = get_last_layer_features_name( + output_name.replace("mtt::aux::", "") + ) + return ll_features_name in requested_outputs diff --git a/src/metatrain/pet_11/modules/adaptive_cutoff.py b/src/metatrain/pet_11/modules/adaptive_cutoff.py new file mode 100644 index 0000000000..5f6206bdb1 --- /dev/null +++ b/src/metatrain/pet_11/modules/adaptive_cutoff.py @@ -0,0 +1,166 @@ +from typing import Optional + +import torch + +from .utilities import cutoff_func_bump as cutoff_func + + +# minimum value for the probe cutoff. this avoids getting too close +# to the central atom. 
in practice it could be also set to a larger value +DEFAULT_MIN_PROBE_CUTOFF = 0.5 +# recommended smooth cutoff width for effective neighbor number calculation +# smaller values lead to a more "step-like" behavior, but can be +# numerically unstable. in practice this will be called with the +# same cutoff as the main cutoff function +DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH = 1.0 + + +def get_adaptive_cutoffs( + centers: torch.Tensor, + edge_distances: torch.Tensor, + num_neighbors_adaptive: float, + num_nodes: int, + max_cutoff: float, + min_cutoff: float = DEFAULT_MIN_PROBE_CUTOFF, + cutoff_width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, + probe_spacing: Optional[float] = None, + weight_width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the adaptive cutoff values for each center atom. + + :param centers: Indices of the center atoms. + :param edge_distances: Distances between centers and their neighbors. + :param num_neighbors_adaptive: Target number of neighbors per atom. + :param num_nodes: Total number of center atoms. + :param max_cutoff: Maximum cutoff distance to consider. + :param min_cutoff: Minimum cutoff distance to consider. + :param cutoff_width: Width of the smooth cutoff taper region. + :param probe_spacing: Spacing between probe cutoffs. If None, it will be + automatically determined from the cutoff width. + :param weight_width: Width of the cutoff selection weight function. If None, it + will be automatically determined from the empirical neighbor counts. + :return: Adapted cutoff distances for each center atom. + """ + + # heuristic for the grid spacing of probe cutoffs, based on a + # the smoothness of the cutoff function + if probe_spacing is None: + probe_spacing = cutoff_width / 4.0 + probe_cutoffs = torch.arange( + min_cutoff, + max_cutoff, + probe_spacing, + device=edge_distances.device, + dtype=edge_distances.dtype, + ) + with torch.profiler.record_function("PET::get_effective_num_neighbors"): + effective_num_neighbors = get_effective_num_neighbors( + edge_distances, + probe_cutoffs, + centers, + num_nodes, + width=cutoff_width, + ) + + with torch.profiler.record_function("PET::get_cutoff_weights"): + cutoffs_weights = get_gaussian_cutoff_weights( + effective_num_neighbors, + num_neighbors_adaptive, + width=weight_width, + ) + with torch.profiler.record_function("PET::calculate_adapted_cutoffs"): + adapted_atomic_cutoffs = probe_cutoffs @ cutoffs_weights.T + return adapted_atomic_cutoffs + + +def get_effective_num_neighbors( + edge_distances: torch.Tensor, + probe_cutoffs: torch.Tensor, + centers: torch.Tensor, + num_nodes: int, + width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, +) -> torch.Tensor: + """ + Computes the effective number of neighbors for each probe cutoff. + + :param edge_distances: Distances between centers and their neighbors. + :param probe_cutoffs: Probe cutoff distances. + :param centers: Indices of the center atoms. + :param num_nodes: Total number of center atoms. + :param width: Width of the cutoff function. + :return: Effective number of neighbors for each center atom and probe cutoff. 
+ """ + + weights = cutoff_func( + edge_distances.unsqueeze(0), probe_cutoffs.unsqueeze(1), width + ) + + probe_num_neighbors = torch.zeros( + (len(probe_cutoffs), num_nodes), + dtype=edge_distances.dtype, + device=edge_distances.device, + ) + # accumulate the weights for all probe cutoffs and center atoms at once + probe_num_neighbors.index_add_(1, centers, weights) + probe_num_neighbors = probe_num_neighbors.T + + return probe_num_neighbors + + +def get_gaussian_cutoff_weights( + effective_num_neighbors: torch.Tensor, + num_neighbors_adaptive: float, + width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the weights for each probe cutoff based on + the effective number of neighbors using Gaussian weights + centered at the expected number of neighbors. + + :param effective_num_neighbors: Effective number of neighbors for each center atom + and probe cutoff. + :param num_neighbors_adaptive: Target maximum number of neighbors per atom. + :param width: Width of the Gaussian cutoff selection function. + :return: Weights for each probe cutoff. + """ + diff = effective_num_neighbors - num_neighbors_adaptive + + # adds a "baseline" corresponding to uniformly-distributed atoms + # this has multiple "good" effects: it pushes the cutoff "out" when + # there are few neighbors, and "in" when there are many, and it + # stabilizes the weights with respect to variations in the neighbor + # distribution when there are empty ranges leading to "flat" + # neighbor count distribution + x = torch.linspace( + 0, + 1, + effective_num_neighbors.shape[1], + device=effective_num_neighbors.device, + dtype=effective_num_neighbors.dtype, + ) + baseline = num_neighbors_adaptive * x**3 + + diff = diff + baseline.unsqueeze(0) + if width is None: + # adaptive width from neighbor-count slope along probe axis (last dim) + eps = 1e-12 + if diff.shape[-1] == 1: + # Can't compute gradient from single point; use scaled diff as proxy + width_t = diff.abs() * 0.5 + eps + else: + # Compute numerical gradient: centered differences for interior, + # one-sided differences at boundaries + (width_t,) = torch.gradient(diff, dim=-1) + width_t = width_t.abs().clamp_min(eps) + else: + width_t = torch.ones_like(diff) * width + + logw = -0.5 * (diff / width_t) ** 2 + weights = torch.exp(logw - logw.max()) + + # row-wise normalization of the weights + weights_sum = weights.sum(dim=1, keepdim=True) + weights = weights / weights_sum + + return weights diff --git a/src/metatrain/pet_11/modules/finetuning.py b/src/metatrain/pet_11/modules/finetuning.py new file mode 100644 index 0000000000..27f2086f19 --- /dev/null +++ b/src/metatrain/pet_11/modules/finetuning.py @@ -0,0 +1,264 @@ +# mypy: disable-error-code=misc +# We ignore misc errors in this file because TypedDict +# with default values is not allowed by mypy. 
+from typing import Any, Optional, Tuple + +import torch +import torch.nn as nn +from typing_extensions import Literal, NotRequired, TypedDict + + +class LoRaFinetuneConfig(TypedDict): + """Configuration for LoRA finetuning strategy.""" + + rank: int + """Rank of the LoRA matrices.""" + alpha: float + """Scaling factor for the LoRA matrices.""" + target_modules: NotRequired[list[str]] + + +class HeadsFinetuneConfig(TypedDict): + """Configuration for heads finetuning strategy.""" + + head_modules: list[str] + """List of module name prefixes for the prediction heads to finetune.""" + last_layer_modules: list[str] + """List of module name prefixes for the last layers to finetune.""" + + +class NoFinetuneHypers(TypedDict): + """Hypers that indicate that no finetuning is to be applied.""" + + read_from: None = None + """No finetuning is indicated by setting this argument to None. + + The rest of finetuning hyperparameters are then ignored. + """ + method: NotRequired[Any] + config: NotRequired[Any] + inherit_heads: NotRequired[Any] + + +class FullFinetuneHypers(TypedDict): + """Hyperparameters to use full finetuning of PET models. + + This means all model parameters are trainable. + """ + + method: Literal["full"] = "full" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: NotRequired[Any] + """No configuration needed for full finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class LoRaFinetuneHypers(TypedDict): + """Hyperparameters for LoRA finetuning of PET models. + + Injects LoRA layers and finetunes only them. + """ + + method: Literal["lora"] = "lora" + """Finetuning method to use""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: LoRaFinetuneConfig + """Configuration for LoRA finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class HeadsFinetuneHypers(TypedDict): + """Hyperparameters for heads finetuning of PET models. + + Freezes all model parameters except for the prediction heads + and last layers. + """ + + method: Literal["heads"] = "heads" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: HeadsFinetuneConfig + """Configuration for heads finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +FinetuneHypers = FullFinetuneHypers | LoRaFinetuneHypers | HeadsFinetuneHypers + + +def apply_finetuning_strategy(model: nn.Module, strategy: FinetuneHypers) -> nn.Module: + """ + Apply the specified finetuning strategy to the model. + This function modifies the model in place based on the provided strategy. + + :param model: The model to be finetuned. + :param strategy: A dictionary specifying the finetuning strategy. + The strategy method can be one of the following: + - lora: Inject LoRA layers into the model, or reapply training if already + present. 
+ - heads: Freeze all parameters except for the heads and last layers. + - full: All parameters are trainable. + Additionally, the strategy can include an "inherit_heads" key, + which is a dictionary mapping the new trainable targets to the existing + targets in the model. This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization. + :return: The modified model with the finetuning strategy applied. + """ + + for param in model.parameters(): + param.requires_grad = True + + if strategy["method"] == "full": + # Full finetuning, all parameters are trainable + pass + + elif strategy["method"] == "lora": + lora_config = strategy["config"] + lora_already_applied = any(isinstance(m, LoRALinear) for m in model.modules()) + if not lora_already_applied: + model_device = next(model.parameters()).device + model_dtype = next(model.parameters()).dtype + model = inject_lora_layers( + model, + target_modules=tuple( + lora_config.get("target_modules", ["input_linear", "output_linear"]) + ), + rank=lora_config.get("rank", 4), + alpha=lora_config.get("alpha", 8), + device=model_device, + dtype=model_dtype, + ) + + # Freeze all except LoRA + for name, param in model.named_parameters(): + if "lora_" not in name: + param.requires_grad = False + + elif strategy["method"] == "heads": + heads_config = strategy.get( + "config", + { + "head_modules": ["node_heads", "edge_heads"], + "last_layer_modules": ["node_last_layers", "edge_last_layers"], + }, + ) + + head_keywords = heads_config.get("head_modules", []) + last_layer_keywords = heads_config.get("last_layer_modules", []) + + for name, param in model.named_parameters(): + if any(name.startswith(kw) for kw in head_keywords + last_layer_keywords): + param.requires_grad = True + else: + param.requires_grad = False + + else: + raise ValueError( + f"Unknown finetuning strategy: {strategy['method']}. Available methods " + "are: 'full', 'lora', 'heads'." + ) + + model.finetune_config = strategy + + inherit_heads_config = strategy["inherit_heads"] + if inherit_heads_config: + for dest_target_name, source_target_name in inherit_heads_config.items(): + model_parameters = dict(model.named_parameters()) + if not any(f".{source_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the source target name '{source_target_name}' was not found in " + "the model. Please specify the correct source target name." + ) + if not any(f".{dest_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the destination target name '{dest_target_name}' was not found " + "in the model. Please specify the correct destination target name." + ) + for name, param in model_parameters.items(): + if f".{source_target_name}." in name: + corresponding_dest_name = name.replace( + source_target_name, dest_target_name + ) + if corresponding_dest_name in model_parameters: + model_parameters[corresponding_dest_name].data.copy_(param.data) + else: + raise ValueError( + f"Destination head '{dest_target_name}' not found in model." + ) + return model + + +def inject_lora_layers( + model: nn.Module, + target_modules: Tuple[str, ...] = ("input_linear", "output_linear"), + rank: int = 4, + alpha: float = 1.0, + device: Optional[str] = None, + dtype: Optional[torch.dtype] = None, +) -> nn.Module: + """ + Inject LoRA layers into the model. 
+ This function replaces the specified linear layers in the model with + LoRALinear layers. + + :param model: The model to modify. + :param target_modules: A tuple of strings specifying the names of the attributes of + the modules to be replaced with LoRA layers. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + :param device: The device to which the LoRA layers should be moved. If None, the + LoRA layers will be on the same device as the original model. + :param dtype: The data type to which the LoRA layers should be converted. If + None, the LoRA layers will have the same dtype as the original model. + :return: The modified model with LoRA layers injected. + """ + for _, module in model.named_modules(): + for attr in target_modules: + if hasattr(module, attr): + linear = getattr(module, attr) + if isinstance(linear, nn.Linear): + lora_linear = LoRALinear(linear, rank=rank, alpha=alpha) + lora_linear = lora_linear.to(dtype=dtype, device=device) + setattr(module, attr, lora_linear) + return model + + +class LoRALinear(nn.Module): + """ + LoRA Linear layer. + This is a wrapper around nn.Linear that adds LoRA functionality. + LoRA is a technique for low-rank adaptation of large language models. + It allows for efficient fine-tuning of large models by injecting low-rank + matrices into the model's weights. + + :param linear_layer: The original linear layer to be wrapped. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + """ + + def __init__(self, linear_layer: nn.Module, rank: int = 4, alpha: float = 1.0): + super().__init__() + self.linear = linear_layer + self.lora_A = nn.Linear(linear_layer.in_features, rank, bias=False) + self.lora_B = nn.Linear(rank, linear_layer.out_features, bias=False) + self.scaling = alpha / rank + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + self.scaling * self.lora_B(self.lora_A(x)) diff --git a/src/metatrain/pet_11/modules/nef.py b/src/metatrain/pet_11/modules/nef.py new file mode 100644 index 0000000000..20ed6dc8fc --- /dev/null +++ b/src/metatrain/pet_11/modules/nef.py @@ -0,0 +1,248 @@ +""" +Module with functions to manipulate NEF (Node Edge Feature) arrays. + +The NEF representation is what the internals of PET use. +In the NEF representation, the first dimension is the center node +(i.e. the "i" node in an "i -> j" edge), and the second dimension +is the edges for that node. Not all center nodes have the same number +of edges, so padding is used to ensure that all nodes have the same +number of edges. + +Most of the functions have the purpose of converting between +edge arrays with shape (n_edges, ...) and NEF arrays with shape +(n_nodes, n_edges_per_node, ...). +""" + +from typing import List, Optional, Tuple + +import torch + + +def get_nef_indices( + centers: torch.Tensor, n_nodes: int, n_edges_per_node: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes tensors of indices useful to convert between edge + and NEF layouts; the usage and function of `nef_indices` and + `nef_to_edges_neighbor` is clear in the ``edge_array_to_nef`` + and ``nef_array_to_edges`` functions below. + + :param centers: A 1D tensor of shape (n_edges,) containing the + indices of the center nodes for each edge, with the center nodes + being the "i" node in an "i -> j" edge. + :param n_nodes: The number of nodes in the graph. + :param n_edges_per_node: The maximum number of edges per node. 
+ + :return: A tuple with three tensors (nef_indices, nef_to_edges_neighbor, nef_mask). + In particular: + nef_array = edge_array[nef_indices] + edge_array = nef_array[centers, nef_to_edges_neighbor] + The third output, nef_mask, is a mask that can be used to + filter out the padding values in the NEF array, as different + nodes will have, in general, different number of edges. + """ + + bincount = torch.bincount(centers, minlength=n_nodes) + + arange = torch.arange(n_edges_per_node, device=centers.device) + arange_expanded = arange.view(1, -1).expand(n_nodes, -1) + nef_mask = arange_expanded < bincount.view(-1, 1) + + argsort = torch.argsort(centers, stable=True) + + nef_indices = torch.zeros( + (n_nodes, n_edges_per_node), dtype=torch.long, device=centers.device + ) + nef_indices[nef_mask] = argsort + + nef_to_edges_neighbor = torch.empty_like(centers, dtype=torch.long) + nef_to_edges_neighbor[argsort] = arange_expanded[nef_mask] + + return nef_indices, nef_to_edges_neighbor, nef_mask + + +def get_corresponding_edges(array: torch.Tensor) -> torch.Tensor: + """ + Computes the corresponding edge (i.e., the edge that goes in the + opposite direction) for each edge in the array; this is useful + in the message-passing operation. + + :param array: A 2D tensor of shape (n_edges, 5). For each i -> j + edge, the first column contains the index of the center node i, + the second column contains the index of the neighbor node j, + and the last three columns contain the cell shifts along x, y, and z + directions, respectively. + + :return: A 1D tensor of shape (n_edges,) containing, for each edge, + the index of the corresponding edge (i.e., the edge that goes + in the opposite direction). If the input array is empty, an + empty tensor is returned. + """ + + if array.numel() == 0: + return torch.empty((0,), dtype=array.dtype, device=array.device) + + array = array.to(torch.int64) # avoid overflow + + centers = array[:, 0] + neighbors = array[:, 1] + cell_shifts_x = array[:, 2] + cell_shifts_y = array[:, 3] + cell_shifts_z = array[:, 4] + + # will be useful later + negative_cell_shifts_x = -cell_shifts_x + negative_cell_shifts_y = -cell_shifts_y + negative_cell_shifts_z = -cell_shifts_z + + # create a unique identifier for each edge + # first, we shift the cell_shifts so that the minimum value is 0 + min_cell_shift_x = cell_shifts_x.min() + cell_shifts_x = cell_shifts_x - min_cell_shift_x + negative_cell_shifts_x = negative_cell_shifts_x - min_cell_shift_x + + min_cell_shift_y = cell_shifts_y.min() + cell_shifts_y = cell_shifts_y - min_cell_shift_y + negative_cell_shifts_y = negative_cell_shifts_y - min_cell_shift_y + + min_cell_shift_z = cell_shifts_z.min() + cell_shifts_z = cell_shifts_z - min_cell_shift_z + negative_cell_shifts_z = negative_cell_shifts_z - min_cell_shift_z + + max_centers_neigbors = centers.max() + 1 # same as neighbors.max() + 1 + max_shift_x = cell_shifts_x.max() + 1 + max_shift_y = cell_shifts_y.max() + 1 + max_shift_z = cell_shifts_z.max() + 1 + + size_1 = max_shift_z + size_2 = max_shift_y * size_1 + size_3 = max_shift_x * size_2 + size_4 = max_centers_neigbors * size_3 + + unique_id = ( + centers * size_4 + + neighbors * size_3 + + cell_shifts_x * size_2 + + cell_shifts_y * size_1 + + cell_shifts_z + ) + + # the inverse is the same, but centers and neighbors are swapped + # and we use the negative values of the cell_shifts + unique_id_inverse = ( + neighbors * size_4 + + centers * size_3 + + negative_cell_shifts_x * size_2 + + negative_cell_shifts_y * size_1 + + 
negative_cell_shifts_z + ) + + unique_id_argsort = unique_id.argsort() + unique_id_inverse_argsort = unique_id_inverse.argsort() + + corresponding_edges = torch.empty_like(centers) + corresponding_edges[unique_id_argsort] = unique_id_inverse_argsort + + return corresponding_edges.to(array.dtype) + + +def edge_array_to_nef( + edge_array: torch.Tensor, + nef_indices: torch.Tensor, + mask: Optional[torch.Tensor] = None, + fill_value: float = 0.0, +) -> torch.Tensor: + """Converts an edge array to a NEF array. + + :param edge_array: A tensor where the first dimension is the index of + the edge, i.e. with shape (n_edges, ...). + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param mask: An optional boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. If provided, + the output NEF array will have the values in the positions + where the mask is False set to ``fill_value``. + :param fill_value: The value to use to fill the positions in the + NEF array where the mask is False. Only used if ``mask`` is + provided. + + :return: A tensor with the same information as ``edge_array``, + but in NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + If ``mask`` is provided, the values in the positions where + the mask is False are set to ``fill_value``. + """ + if mask is None: + return edge_array[nef_indices] + else: + return torch.where( + mask.reshape(mask.shape + (1,) * (len(edge_array.shape) - 1)), + edge_array[nef_indices], + fill_value, + ) + + +def nef_array_to_edges( + nef_array: torch.Tensor, centers: torch.Tensor, nef_to_edges_neighbor: torch.Tensor +) -> torch.Tensor: + """Converts a NEF array to an edge array. + + :param nef_array: A tensor where the first two dimensions are the + indices of the NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + :param centers: The indices of the center nodes for each edge. + :param nef_to_edges_neighbor: The indices of the edges for each + neighbor in the NEF layout, as returned by the ``get_nef_indices`` function. + + :return: A tensor with the same information as ``nef_array``, + but in edge layout, i.e. with shape (n_edges, ...). + """ + return nef_array[centers, nef_to_edges_neighbor] + + +def compute_reversed_neighbor_list( + nef_indices: torch.Tensor, + corresponding_edges: torch.Tensor, + nef_mask: torch.Tensor, +) -> torch.Tensor: + """ + Creates a reversed neighborlist, where for each + center atom `i` and its neighbor `j` in the original + neighborlist, the position of atom `i` in the list + of neighbors of atom `j` is returned. + + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param corresponding_edges: The indices of the corresponding edges, + as returned by the ``get_corresponding_edges`` function. + :param nef_mask: A boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. + :return: A tensor of the same shape as ``nef_indices``, + where each entry contains the position of the center + atom in the neighborlist of the corresponding neighbor atom. 
+ """ + num_atoms, max_num_neighbors = nef_indices.shape + + flat_edge_indices = nef_indices.reshape(-1) + flat_positions = torch.arange(max_num_neighbors, device=nef_indices.device).repeat( + num_atoms + ) + flat_mask = nef_mask.reshape(-1) + + if flat_edge_indices.numel() == 0: + max_edge_index = 0 + else: + max_edge_index = int(flat_edge_indices.max().item()) + 1 + size: List[int] = [max_edge_index] + + edge_index_to_position = torch.full( + size, + 0, + dtype=torch.long, + device=nef_indices.device, + ) + edge_index_to_position[flat_edge_indices[flat_mask]] = flat_positions[flat_mask] + + reverse_edge_idx = corresponding_edges[nef_indices] + reversed_neighbor_list = edge_index_to_position[reverse_edge_idx] + reversed_neighbor_list = reversed_neighbor_list.masked_fill(~nef_mask, 0) + + return reversed_neighbor_list diff --git a/src/metatrain/pet_11/modules/structures.py b/src/metatrain/pet_11/modules/structures.py new file mode 100644 index 0000000000..f7e67f9eaf --- /dev/null +++ b/src/metatrain/pet_11/modules/structures.py @@ -0,0 +1,377 @@ +from typing import List, Optional, Tuple + +import torch +from metatensor.torch import Labels +from metatomic.torch import NeighborListOptions, System + +from .adaptive_cutoff import get_adaptive_cutoffs +from .nef import ( + compute_reversed_neighbor_list, + edge_array_to_nef, + get_corresponding_edges, + get_nef_indices, +) +from .utilities import cutoff_func_bump, cutoff_func_cosine + + +def concatenate_structures( + systems: List[System], + neighbor_list_options: NeighborListOptions, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, +]: + """ + Concatenate a list of systems into a single batch. + + :param systems: List of systems to concatenate. + :param neighbor_list_options: Options for the neighbor list. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the concatenated positions, centers, neighbors, + species, cells, cell shifts, system indices, and sample labels. 
+ """ + + positions: List[torch.Tensor] = [] + centers: List[torch.Tensor] = [] + neighbors: List[torch.Tensor] = [] + species: List[torch.Tensor] = [] + cell_shifts: List[torch.Tensor] = [] + cells: List[torch.Tensor] = [] + system_indices: List[torch.Tensor] = [] + atom_indices: List[torch.Tensor] = [] + node_counter = 0 + + for i, system in enumerate(systems): + assert len(system.known_neighbor_lists()) >= 1, "no neighbor list found" + neighbor_list = system.get_neighbor_list(neighbor_list_options) + nl_values = neighbor_list.samples.values + + centers_values = nl_values[:, 0] + neighbors_values = nl_values[:, 1] + cell_shifts_values = nl_values[:, 2:] + + if selected_atoms is not None: + system_selected_atoms = selected_atoms.values[:, 1][ + selected_atoms.values[:, 0] == i + ] + unique_centers = torch.unique(centers_values) + system_selected_atoms = torch.unique( + torch.cat([system_selected_atoms, unique_centers]) + ) + # calculate the mapping from the ghost atoms to the real atoms + if torch.numel(unique_centers) == 0: + max_center_index = -1 + else: + max_center_index = int(unique_centers.max()) + ghost_to_real_index = torch.full( + [ + max_center_index + 1, + ], + -1, + device=centers_values.device, + dtype=centers_values.dtype, + ) + for j, unique_center_index in enumerate(unique_centers): + ghost_to_real_index[unique_center_index] = j + + centers_values = ghost_to_real_index[centers_values] + neighbors_values = ghost_to_real_index[neighbors_values] + else: + system_selected_atoms = torch.arange( + len(system), device=system.positions.device + ) + + positions.append(system.positions[system_selected_atoms]) + species.append(system.types[system_selected_atoms]) + + centers.append(centers_values + node_counter) + neighbors.append(neighbors_values + node_counter) + cell_shifts.append(cell_shifts_values) + + cells.append(system.cell) + + node_counter += len(system_selected_atoms) + system_indices.append( + torch.full((len(system_selected_atoms),), i, device=system.positions.device) + ) + atom_indices.append( + torch.arange(len(system_selected_atoms), device=system.positions.device) + ) + + positions = torch.cat(positions) + centers = torch.cat(centers) + neighbors = torch.cat(neighbors) + species = torch.cat(species) + cells = torch.stack(cells) + cell_shifts = torch.cat(cell_shifts) + system_indices = torch.cat(system_indices) + atom_indices = torch.cat(atom_indices) + + sample_values = torch.stack( + [system_indices, atom_indices], + dim=1, + ) + sample_labels = Labels( + names=["system", "atom"], + values=sample_values, + assume_unique=True, + ) + + return ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) + + +def systems_to_batch( + systems: List[System], + options: NeighborListOptions, + all_species_list: List[int], + species_to_species_index: torch.Tensor, + cutoff_function: str, + cutoff_width: float, + num_neighbors_adaptive: Optional[float] = None, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, + torch.Tensor, + torch.Tensor, +]: + """ + Converts a list of systems to a batch required for the PET model. + + :param systems: List of systems to convert to a batch. + :param options: Options for the neighbor list. + :param all_species_list: List of all atomic species in the dataset. + :param species_to_species_index: Mapping from atomic species to species indices. 
+ :param cutoff_function: Type of the smoothing function at the cutoff. + :param cutoff_width: Width of the cutoff function for a cutoff mask. + :param num_neighbors_adaptive: Optional maximum number of neighbors per atom. + If provided, the adaptive cutoff scheme will be used for each atom to + approximately select this number of neighbors. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the batch tensors. + The batch consists of the following tensors: + - `element_indices_nodes`: The atomic species of the central atoms + - `element_indices_neighbors`: The atomic species of the neighboring atoms + - `edge_vectors`: The cartesian edge vectors between the central atoms and their + neighbors + - `edge_distances`: The distances between the central atoms and their neighbors + - `padding_mask`: A padding mask indicating which neighbors are real, and which + are padded + - `reverse_neighbor_index`: The reversed neighbor list for each central atom + - `cutoff_factors`: The cutoff function values for each edge + - `system_indices`: The system index for each atom in the batch + - `sample_labels`: Labels indicating the system and atom indices for each atom + + """ + ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) = concatenate_structures(systems, options, selected_atoms) + + # somehow the backward of this operation is very slow at evaluation, + # where there is only one cell, therefore we simplify the calculation + # for that case + if len(cells) == 1: + cell_contributions = cell_shifts.to(cells.dtype) @ cells[0] + else: + cell_contributions = torch.einsum( + "ab, abc -> ac", + cell_shifts.to(cells.dtype), + cells[system_indices[centers]], + ) + edge_vectors = positions[neighbors] - positions[centers] + cell_contributions + edge_distances = torch.norm(edge_vectors, dim=-1) + 1e-15 + + if selected_atoms is not None: + if torch.numel(centers) == 0: + num_nodes = 0 + else: + num_nodes = int(centers.max()) + 1 + else: + num_nodes = len(positions) + + if num_neighbors_adaptive is not None: + with torch.profiler.record_function("PET::get_adaptive_cutoffs"): + # Adaptive cutoff scheme to approximately select `num_neighbors_adaptive` + # neighbors for each atom + atomic_cutoffs = get_adaptive_cutoffs( + centers, + edge_distances, + num_neighbors_adaptive, + num_nodes, + options.cutoff, + cutoff_width=cutoff_width, + ) + # Symmetrize the cutoffs between pairs of atoms (PET needs this symmetry + # due to its corresponding edge indexing ij -> ji) + pair_cutoffs = (atomic_cutoffs[centers] + atomic_cutoffs[neighbors]) / 2.0 + with torch.profiler.record_function("PET::adaptive_cutoff_masking"): + # Apply cutoff mask + cutoff_mask = edge_distances <= pair_cutoffs + + pair_cutoffs = pair_cutoffs[cutoff_mask] + centers = centers[cutoff_mask] + neighbors = neighbors[cutoff_mask] + edge_vectors = edge_vectors[cutoff_mask] + cell_shifts = cell_shifts[cutoff_mask] + edge_distances = edge_distances[cutoff_mask] + else: + pair_cutoffs = options.cutoff * torch.ones( + len(centers), device=positions.device, dtype=positions.dtype + ) + + num_neighbors = torch.bincount(centers, minlength=num_nodes) + max_edges_per_node = int(torch.max(num_neighbors)) + + # uncomment these to print out stats on the adaptive cutoff behavior + # print("adaptive_cutoffs", *pair_cutoffs.tolist()) + # print("num_neighbors", *num_neighbors.tolist()) + + if cutoff_function.lower() == "bump": + # use bump switching 
function for adaptive cutoff + cutoff_factors = cutoff_func_bump(edge_distances, pair_cutoffs, cutoff_width) + elif cutoff_function.lower() == "cosine": + # backward-compatible cosine swithcing for fixed cutoff + cutoff_factors = cutoff_func_cosine(edge_distances, pair_cutoffs, cutoff_width) + else: + raise ValueError( + f"Unknown cutoff function type: {cutoff_function}. " + f"Supported types are 'Cosine' and 'Bump'." + ) + + # Convert to NEF (Node-Edge-Feature) format: + nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices( + centers, num_nodes, max_edges_per_node + ) + + # Element indices + element_indices_nodes = species_to_species_index[species] + element_indices_neighbors = element_indices_nodes[neighbors] + + # Send everything to NEF: + edge_vectors = edge_array_to_nef(edge_vectors, nef_indices) + edge_distances = torch.sqrt(torch.sum(edge_vectors**2, dim=2) + 1e-15) + element_indices_neighbors = edge_array_to_nef( + element_indices_neighbors, nef_indices + ) + cutoff_factors = edge_array_to_nef(cutoff_factors, nef_indices, nef_mask, 0.0) + + corresponding_edges = get_corresponding_edges( + torch.concatenate( + [centers.unsqueeze(-1), neighbors.unsqueeze(-1), cell_shifts], + dim=-1, + ) + ) + + # These are the two arrays we need for message passing with edge reversals, + # if indexing happens in a two-dimensional way: + # edges_ji = edges_ij[reversed_neighbor_list, neighbors_index] + reversed_neighbor_list = compute_reversed_neighbor_list( + nef_indices, corresponding_edges, nef_mask + ) + neighbors_index = edge_array_to_nef(neighbors, nef_indices).to(torch.int64) + + # Here, we compute the array that allows indexing into a flattened + # version of the edge array (where the first two dimensions are merged): + reverse_neighbor_index = ( + neighbors_index * neighbors_index.shape[1] + reversed_neighbor_list + ) + # At this point, we have `reverse_neighbor_index[~nef_mask] = 0`, which however + # creates too many of the same index which slows down backward enormously. + # (See see https://github.com/pytorch/pytorch/issues/41162) + # We therefore replace the padded indices with a sequence of unique indices. + reverse_neighbor_index[~nef_mask] = torch.arange( + int(torch.sum(~nef_mask)), device=reverse_neighbor_index.device + ) + + return ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + nef_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + sample_labels, + centers, + nef_to_edges_neighbor, + ) + + +def get_pair_sample_labels( + systems: List[System], + sample_labels: Labels, + nl_options: NeighborListOptions, + device: torch.device, +) -> Labels: + """ + Builds the pair samples labels for the input ``systems``, based on the pre-computed + neighbor list. These are 'off-site', i.e. not including self-interactions. + + :param systems: List of systems to build the pair sample labels for. + :param sample_labels: The sample labels for per-atom quantities. + :param nl_options: The neighbor list options to use for building the offsite labels. + :param device: The device to put the labels on. + :return: A dictionary with the pair sample labels for the onsite and offsite blocks. 
+ """ + sample_names = [ + "system", + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ] + + pair_sample_values = [] + for system_idx, system in enumerate(systems): + neighbor_list = system.get_neighbor_list(nl_options) + nl_values = neighbor_list.samples.values + + pair_sample_values.append( + torch.hstack( + [ + torch.full( + (nl_values.shape[0], 1), + system_idx, + dtype=torch.int32, + device=device, + ), + nl_values, + ], + ) + ) + pair_sample_values = torch.vstack(pair_sample_values) + pair_sample_labels = Labels(sample_names, pair_sample_values).to(device=device) + + return pair_sample_labels diff --git a/src/metatrain/pet_11/modules/transformer.py b/src/metatrain/pet_11/modules/transformer.py new file mode 100644 index 0000000000..883c37a0ac --- /dev/null +++ b/src/metatrain/pet_11/modules/transformer.py @@ -0,0 +1,556 @@ +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from .utilities import DummyModule +from sphericart.torch import SolidHarmonics + + +AVAILABLE_NORMALIZATIONS = ["LayerNorm", "RMSNorm"] +AVAILABLE_TRANSFORMER_TYPES = ["PostLN", "PreLN"] +AVAILABLE_ACTIVATIONS = ["SiLU", "SwiGLU"] + + +class FeedForward(nn.Module): + def __init__(self, d_model: int, dim_feedforward: int, activation: str) -> None: + super().__init__() + + # Check if activation is "swiglu" string + if activation.lower() == "swiglu": + # SwiGLU mode: single projection produces both "value" and "gate" + self.w_in = nn.Linear(d_model, 2 * dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = torch.nn.Identity() + self.is_swiglu = True + else: + # Standard mode: regular activation function + self.w_in = nn.Linear(d_model, dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = getattr(F, activation.lower()) + self.is_swiglu = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_swiglu: + # SwiGLU activation: split into value and gate + v, g = self.w_in(x).chunk(2, dim=-1) + x = v * torch.sigmoid(g) + x = self.w_out(x) + else: + # Standard activation + x = self.w_in(x) + x = self.activation(x) + x = self.w_out(x) + return x + + +class AttentionBlock(nn.Module): + """ + Multi-head attention block. + + :param total_dim: The total dimension of the input and output tensors. + :param num_heads: The number of attention heads. + :param temperature: An additional scaling factor for attention scores. + This is combined with the standard scaling by the square root of + the head dimension. + :param epsilon: A small value to avoid division by zero. + """ + + def __init__( + self, + total_dim: int, + num_heads: int, + temperature: float, + epsilon: float = 1e-15, + ) -> None: + super(AttentionBlock, self).__init__() + + self.input_linear = nn.Linear(total_dim, 3 * total_dim) + self.output_linear = nn.Linear(total_dim, total_dim) + + self.num_heads = num_heads + self.epsilon = epsilon + self.temperature = temperature + if total_dim % num_heads != 0: + raise ValueError("total dimension is not divisible by the number of heads") + self.head_dim = total_dim // num_heads + + def forward( + self, x: torch.Tensor, cutoff_factors: torch.Tensor, use_manual_attention: bool + ) -> torch.Tensor: + """ + Forward pass for the attention block. 
+ + :param x: The input tensor, of shape (batch_size, seq_length, total_dim) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: The output tensor, of shape (batch_size, seq_length, total_dim) + """ + initial_shape = x.shape + x = self.input_linear(x) + x = x.reshape( + initial_shape[0], initial_shape[1], 3, self.num_heads, self.head_dim + ) + x = x.permute(2, 0, 3, 1, 4) + + queries, keys, values = x[0], x[1], x[2] + attn_weights = torch.clamp(cutoff_factors[:, None, :, :], self.epsilon) + attn_weights = torch.log(attn_weights) + if use_manual_attention: + x = manual_attention(queries, keys, values, attn_weights, self.temperature) + else: + x = torch.nn.functional.scaled_dot_product_attention( + queries, + keys, + values, + attn_mask=attn_weights, + scale=1.0 / (self.head_dim**0.5 * self.temperature), + ) + x = x.transpose(1, 2).reshape(initial_shape) + x = self.output_linear(x) + return x + + +class TransformerLayer(torch.nn.Module): + """ + Single layer of a Transformer. + + :param d_model: The dimension of the model. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param temperature: An additional scaling factor for attention scores. 
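+
+    As a rough sketch of the two variants (ignoring the separate handling of the
+    node features), "PostLN" and "PreLN" apply the attention and feed-forward
+    sub-blocks as::
+
+        # PostLN: normalize after each residual update
+        x = norm_attention(x + attention(x))
+        x = norm_mlp(x + mlp(x))
+
+        # PreLN: normalize the input of each sub-block
+        x = x + attention(norm_attention(x))
+        x = x + mlp(norm_mlp(x))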
+ """ + + def __init__( + self, + d_model: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + temperature: float = 1.0, + ) -> None: + super(TransformerLayer, self).__init__() + self.attention = AttentionBlock(d_model, n_heads, temperature) + self.transformer_type = transformer_type + self.d_model = d_model + norm_class = getattr(nn, norm) + self.norm_attention = norm_class(d_model) + self.norm_mlp = norm_class(d_model) + self.mlp = FeedForward(d_model, dim_feedforward, activation) + self.expanded_node_features = False + if dim_node_features != d_model: + self.expanded_node_features = True + self.center_contraction = nn.Linear(dim_node_features, d_model) + self.center_expansion = nn.Linear(d_model, dim_node_features) + self.norm_center_features = norm_class(dim_node_features) + self.center_mlp = FeedForward( + dim_node_features, 2 * dim_node_features, activation + ) + else: + self.center_contraction = torch.nn.Identity() + self.center_expansion = torch.nn.Identity() + self.norm_center_features = torch.nn.Identity() + self.center_mlp = torch.nn.Identity() + + def _forward_pre_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + new_tokens = self.attention( + self.norm_attention(tokens), cutoff_factors, use_manual_attention + ) + output_node_embeddings, output_edge_embeddings = torch.split( + new_tokens, [1, new_tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + + output_edge_embeddings = edge_embeddings + output_edge_embeddings + output_edge_embeddings = output_edge_embeddings + self.mlp( + self.norm_mlp(output_edge_embeddings) + ) + + return output_node_embeddings, output_edge_embeddings + + def _forward_post_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + tokens = self.norm_attention( + tokens + self.attention(tokens, cutoff_factors, use_manual_attention) + ) + tokens = self.norm_mlp(tokens + self.mlp(tokens)) + output_node_embeddings, output_edge_embeddings = torch.split( + tokens, [1, tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + return output_node_embeddings, output_edge_embeddings + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward 
pass for a single Transformer layer. + + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + if self.transformer_type == "PostLN": + node_embeddings, edge_embeddings = self._forward_post_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + if self.transformer_type == "PreLN": + node_embeddings, edge_embeddings = self._forward_pre_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class Transformer(torch.nn.Module): + """ + Transformer implementation. + + :param d_model: The dimension of the model. + :param num_layers: The number of transformer layers. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param attention_temperature: The temperature scaling factor for attention + scores. This is combined with the standard scaling by the square root of + the head dimension. + """ + + def __init__( + self, + d_model: int, + num_layers: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + attention_temperature: float = 1.0, + ) -> None: + super(Transformer, self).__init__() + if norm not in AVAILABLE_NORMALIZATIONS: + raise ValueError( + f"Unknown normalization flag: {norm}. " + f"Please choose from: {AVAILABLE_NORMALIZATIONS}" + ) + + if transformer_type not in AVAILABLE_TRANSFORMER_TYPES: + raise ValueError( + f"Unknown transformer flag: {transformer_type}. " + f"Please choose from: {AVAILABLE_TRANSFORMER_TYPES}" + ) + self.transformer_type = transformer_type + + if activation not in AVAILABLE_ACTIVATIONS: + raise ValueError( + f"Unknown activation flag: {activation}. " + f"Please choose from: {AVAILABLE_ACTIVATIONS}" + ) + + self.layers = nn.ModuleList( + [ + TransformerLayer( + d_model=d_model, + n_heads=n_heads, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + temperature=attention_temperature, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the Transformer. 
+ + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + for layer in self.layers: + node_embeddings, edge_embeddings = layer( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class CartesianTransformer(torch.nn.Module): + """ + Cartesian Transformer implementation for handling 3D coordinates. + + :param cutoff: The cutoff distance for neighbor interactions. + :param cutoff_width: The width of the cutoff function. + :param d_model: The dimension of the model. + :param n_head: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param n_layers: The number of transformer layers. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param attention_temperature: The temperature scaling factor for attention scores. + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param n_atomic_species: The number of atomic species. + :param is_first: Whether this is the first transformer in the model. + """ + + def __init__( + self, + cutoff: float, + cutoff_width: float, + d_model: int, + n_head: int, + dim_node_features: int, + dim_feedforward: int, + n_layers: int, + norm: str, + activation: str, + attention_temperature: float, + transformer_type: str, + n_atomic_species: int, + is_first: bool, + ) -> None: + super(CartesianTransformer, self).__init__() + self.is_first = is_first + self.cutoff = cutoff + self.cutoff_width = cutoff_width + self.trans = Transformer( + d_model=d_model, + num_layers=n_layers, + n_heads=n_head, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + attention_temperature=attention_temperature, + ) + + self.l_max = 2 + self.spherical_harmonics = SolidHarmonics(l_max=self.l_max) + self.edge_embedder = nn.Linear((self.l_max + 1) ** 2, d_model) + self.rmsnorm = nn.LayerNorm(d_model) + + if not is_first: + n_merge = 3 + else: + n_merge = 2 + + self.compress = nn.Sequential( + nn.Linear(n_merge * d_model, d_model), + torch.nn.SiLU(), + nn.Linear(d_model, d_model), + ) + + self.neighbor_embedder = DummyModule() # for torchscript + if not is_first: + self.neighbor_embedder = nn.Embedding(n_atomic_species, d_model) + + def forward( + self, + input_node_embeddings: torch.Tensor, + input_messages: torch.Tensor, + element_indices_neighbors: torch.Tensor, + edge_vectors: torch.Tensor, + padding_mask: torch.Tensor, + edge_distances: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the CartesianTransformer. 
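+
+        The edge vectors are first expanded in solid harmonics up to
+        ``l_max = 2``, i.e. ``(l_max + 1) ** 2 = 9`` components per edge, then
+        projected to ``d_model`` by ``edge_embedder`` and normalized, before being
+        combined with the embeddings of the neighbor species (except in the first
+        GNN layer) and the incoming messages.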
+ + :param input_node_embeddings: The input node embeddings, of shape + (n_nodes, d_model) + :param input_messages: The input messages to the transformer, of shape + (n_nodes, max_num_neighbors, d_model) + :param element_indices_neighbors: The atomic species of the neighboring atoms, + of shape (n_nodes, max_num_neighbors) + :param edge_vectors: The cartesian edge vectors between the central atoms and + their neighbors, of shape (n_nodes, max_num_neighbors, 3) + :param padding_mask: A padding mask indicating which neighbors are real, and + which are padded, of shape (n_nodes, max_num_neighbors) + :param edge_distances: The distances between the central atoms and their + neighbors, of shape (n_nodes, max_num_neighbors) + :param cutoff_factors: The cutoff factors for the edges, of shape + (n_nodes, max_num_neighbors) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (n_nodes, d_model) + - The output edge embeddings, of shape (n_nodes, max_num_neighbors, d_model) + """ + node_embeddings = input_node_embeddings + # edge_embeddings = [edge_vectors, edge_distances[:, :, None]] + # edge_embeddings = torch.cat(edge_embeddings, dim=2) + edge_embeddings = self.spherical_harmonics( + edge_vectors.reshape(-1, 3) + ).reshape(edge_vectors.shape[0], edge_vectors.shape[1], -1) + edge_embeddings = self.edge_embedder(edge_embeddings) + edge_embeddings = self.rmsnorm(edge_embeddings) + + if not self.is_first: + neighbor_elements_embeddings = self.neighbor_embedder( + element_indices_neighbors + ) + edge_tokens = torch.cat( + [edge_embeddings, neighbor_elements_embeddings, input_messages], dim=2 + ) + else: + neighbor_elements_embeddings = torch.empty( + 0, device=edge_vectors.device, dtype=edge_vectors.dtype + ) # for torch script + edge_tokens = torch.cat([edge_embeddings, input_messages], dim=2) + + edge_tokens = self.compress(edge_tokens) + # tokens = torch.cat([node_elements_embedding[:, None, :], tokens], dim=1) + + padding_mask_with_central_token = torch.ones( + padding_mask.shape[0], dtype=torch.bool, device=padding_mask.device + ) + total_padding_mask = torch.cat( + [padding_mask_with_central_token[:, None], padding_mask], dim=1 + ) + + cutoff_subfactors = torch.ones( + padding_mask.shape[0], + dtype=cutoff_factors.dtype, + device=padding_mask.device, + ) + cutoff_factors = torch.cat([cutoff_subfactors[:, None], cutoff_factors], dim=1) + cutoff_factors[~total_padding_mask] = 0.0 + + cutoff_factors = cutoff_factors[:, None, :] + cutoff_factors = cutoff_factors.repeat(1, cutoff_factors.shape[2], 1) + + initial_num_tokens = edge_vectors.shape[1] + max_num_tokens = input_messages.shape[1] + + output_node_embeddings, output_edge_embeddings = self.trans( + node_embeddings[:, None, :], + edge_tokens[:, :max_num_tokens, :], + cutoff_factors=cutoff_factors[ + :, : (max_num_tokens + 1), : (max_num_tokens + 1) + ], + use_manual_attention=use_manual_attention, + ) + if max_num_tokens < initial_num_tokens: + padding = torch.zeros( + output_edge_embeddings.shape[0], + initial_num_tokens - max_num_tokens, + output_edge_embeddings.shape[2], + device=output_edge_embeddings.device, + ) + output_edge_embeddings = torch.cat([output_edge_embeddings, padding], dim=1) + output_node_embeddings = output_node_embeddings.squeeze(1) + return output_node_embeddings, 
output_edge_embeddings + + +def manual_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: torch.Tensor, + temperature: float, +) -> torch.Tensor: + """ + Implements the attention operation manually, using basic PyTorch operations. + We need it because the built-in PyTorch attention does not support double backward, + which is needed when training with conservative forces. + + :param q: The queries + :param k: The keys + :param v: The values + :param attn_mask: The attention mask + :param temperature: An additional scaling factor for attention scores. + :return: The result of the attention operation + """ + attention_weights = ( + torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5 * temperature) + ) + attn_mask + attention_weights = attention_weights.softmax(dim=-1) + attention_output = torch.matmul(attention_weights, v) + return attention_output \ No newline at end of file diff --git a/src/metatrain/pet_11/modules/utilities.py b/src/metatrain/pet_11/modules/utilities.py new file mode 100644 index 0000000000..5795c5aed8 --- /dev/null +++ b/src/metatrain/pet_11/modules/utilities.py @@ -0,0 +1,63 @@ +import torch + + +def cutoff_func_bump( + values: torch.Tensor, cutoff: torch.Tensor, width: float, eps: float = 1e-6 +) -> torch.Tensor: + """ + Bump cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :param eps: Avoid computing at values too close to the edges. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + f[mask_active] = 0.5 * ( + 1 + torch.tanh(1 / torch.tan(torch.pi * scaled_values[mask_active])) + ) + f[mask_smaller] = 1.0 + + return f + + +def cutoff_func_cosine( + values: torch.Tensor, cutoff: torch.Tensor, width: float +) -> torch.Tensor: + """ + Cosine cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + + f[mask_active] = 0.5 + 0.5 * torch.cos(torch.pi * scaled_values[mask_active]) + f[mask_smaller] = 1.0 + return f + + +class DummyModule(torch.nn.Module): + """Dummy torch module to make torchscript happy. 
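+    It is only used as a typed placeholder (e.g. for optional sub-modules such as
+    the neighbor embedder), so that the containing model can still be scripted.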
+ This model should never be run""" + + def __init__(self) -> None: + super(DummyModule, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise RuntimeError("This model should never be run") diff --git a/src/metatrain/pet_11/trainer.py b/src/metatrain/pet_11/trainer.py new file mode 100644 index 0000000000..49e82212bc --- /dev/null +++ b/src/metatrain/pet_11/trainer.py @@ -0,0 +1,639 @@ +import copy +import logging +import math +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Union, cast + +import torch +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader, DistributedSampler + +from metatrain.utils.abc import ModelInterface, TrainerInterface +from metatrain.utils.additive import get_remove_additive_transform +from metatrain.utils.augmentation import RotationalAugmenter +from metatrain.utils.data import ( + CollateFn, + CombinedDataLoader, + Dataset, + get_num_workers, + unpack_batch, + validate_num_workers, +) +from metatrain.utils.distributed.batch_utils import should_skip_batch +from metatrain.utils.distributed.distributed_data_parallel import ( + DistributedDataParallel, +) +from metatrain.utils.distributed.slurm import DistributedEnvironment +from metatrain.utils.evaluate_model import evaluate_model +from metatrain.utils.io import check_file_extension +from metatrain.utils.logging import ROOT_LOGGER, MetricLogger +from metatrain.utils.loss import LossAggregator, LossSpecification +from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists_transform, +) +from metatrain.utils.per_atom import average_by_num_atoms +from metatrain.utils.scaler import get_remove_scale_transform +from metatrain.utils.transfer import batch_to + +from . import checkpoints +from .documentation import TrainerHypers +from .model import PET +from .modules.finetuning import apply_finetuning_strategy + + +def get_scheduler( + optimizer: torch.optim.Optimizer, + train_hypers: TrainerHypers, + steps_per_epoch: int, +) -> LambdaLR: + """ + Get a CosineAnnealing learning-rate scheduler with warmup + + :param optimizer: The optimizer for which to create the scheduler. + :param train_hypers: The training hyperparameters. + :param steps_per_epoch: The number of steps per epoch. + :return: The learning rate scheduler. 
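+
+    As a sketch, the multiplier applied to the base learning rate at step ``t``
+    is (the minimum learning-rate ratio is currently hardcoded to zero)::
+
+        if t < warmup_steps:
+            factor = t / warmup_steps                            # linear warmup
+        else:
+            progress = (t - warmup_steps) / (total_steps - warmup_steps)
+            factor = 0.5 * (1.0 + cos(pi * progress))            # cosine decay to 0
+
+    where ``total_steps = num_epochs * steps_per_epoch`` and
+    ``warmup_steps = warmup_fraction * total_steps``.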
+ """ + total_steps = train_hypers["num_epochs"] * steps_per_epoch + warmup_steps = int(train_hypers["warmup_fraction"] * total_steps) + min_lr_ratio = 0.0 # hardcoded for now, could be made configurable in the future + + def lr_lambda(current_step: int) -> float: + if current_step < warmup_steps: + # Linear warmup + return float(current_step) / float(max(1, warmup_steps)) + else: + # Cosine decay + progress = (current_step - warmup_steps) / float( + max(1, total_steps - warmup_steps) + ) + cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) + return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay + + scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) + return scheduler + + +class Trainer(TrainerInterface[TrainerHypers]): + __checkpoint_version__ = 12 + + def __init__(self, hypers: TrainerHypers) -> None: + super().__init__(hypers) + + self.optimizer_state_dict: Optional[Dict[str, Any]] = None + self.scheduler_state_dict: Optional[Dict[str, Any]] = None + self.epoch: Optional[int] = None + self.best_epoch: Optional[int] = None + self.best_metric: Optional[float] = None + self.best_model_state_dict: Optional[Dict[str, Any]] = None + self.best_optimizer_state_dict: Optional[Dict[str, Any]] = None + + def train( + self, + model: PET, + dtype: torch.dtype, + devices: List[torch.device], + train_datasets: List[Union[Dataset, torch.utils.data.Subset]], + val_datasets: List[Union[Dataset, torch.utils.data.Subset]], + checkpoint_dir: str, + ) -> None: + assert dtype in PET.__supported_dtypes__ + + is_distributed = self.hypers["distributed"] + is_finetune = self.hypers["finetune"]["read_from"] is not None + + if is_distributed: + if len(devices) > 1: + raise ValueError( + "Requested distributed training with the `multi-gpu` device. " + " If you want to run distributed training with PET, please " + "set `device` to cuda." 
+ ) + # the calculation of the device number works both when GPUs on different + # processes are not visible to each other and when they are + distr_env = DistributedEnvironment(self.hypers["distributed_port"]) + device_number = distr_env.local_rank % torch.cuda.device_count() + device = torch.device("cuda", device_number) + torch.distributed.init_process_group(backend="nccl", device_id=device) + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + else: + rank = 0 + device = devices[0] + # only one device, as we don't support non-distributed multi-gpu for now + + if is_distributed: + logging.info(f"Training on {world_size} devices with dtype {dtype}") + else: + logging.info(f"Training on device {device} with dtype {dtype}") + + # Apply fine-tuning strategy if provided + if is_finetune: + assert self.hypers["finetune"]["read_from"] is not None # for mypy + model = apply_finetuning_strategy(model, self.hypers["finetune"]) + method = self.hypers["finetune"]["method"] + num_params = sum(p.numel() for p in model.parameters()) + num_trainable_params = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + + logging.info(f"Applied finetuning strategy: {method}") + logging.info( + f"Number of trainable parameters: {num_trainable_params} " + f"[{num_trainable_params / num_params:.2%} %]" + ) + inherit_heads = self.hypers["finetune"]["inherit_heads"] + if inherit_heads: + logging.info( + "Inheriting initial weights for heads and last layers for targets: " + f"from {list(inherit_heads.values())} to " + f"{list(inherit_heads.keys())}" + ) + + # Move the model to the device and dtype: + model.to(device=device, dtype=dtype) + # The additive models of PET are always in float64 (to avoid numerical errors in + # the composition weights, which can be very large). 
+ for additive_model in model.additive_models: + additive_model.to(dtype=torch.float64) + model.scaler.to(dtype=torch.float64) + + logging.info("Calculating composition weights") + model.additive_models[0].train_model( # this is the composition model + train_datasets, + model.additive_models[1:], + self.hypers["batch_size"], + is_distributed, + self.hypers["atomic_baseline"], + ) + + if self.hypers["scale_targets"]: + logging.info("Calculating scaling weights") + model.scaler.train_model( + train_datasets, + model.additive_models, + self.hypers["batch_size"], + is_distributed, + self.hypers["fixed_scaling_weights"], + ) + + logging.info("Setting up data loaders") + + if is_distributed: + train_samplers = [ + DistributedSampler( + train_dataset, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + ) + for train_dataset in train_datasets + ] + val_samplers = [ + DistributedSampler( + val_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + for val_dataset in val_datasets + ] + else: + train_samplers = [None] * len(train_datasets) + val_samplers = [None] * len(val_datasets) + + # Extract additive models and scaler and move them to CPU/float64 so they + # can be used in the collate function + model.additive_models[0].weights_to(device="cpu", dtype=torch.float64) + additive_models = copy.deepcopy( + model.additive_models.to(dtype=torch.float64, device="cpu") + ) + model.additive_models.to(device) + model.additive_models[0].weights_to(device=device, dtype=torch.float64) + model.scaler.scales_to(device="cpu", dtype=torch.float64) + scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu")) + model.scaler.to(device) + model.scaler.scales_to(device=device, dtype=torch.float64) + + # Create collate functions: + dataset_info = model.dataset_info + train_targets = dataset_info.targets + extra_data_info = dataset_info.extra_data + rotational_augmenter = RotationalAugmenter( + target_info_dict=train_targets, extra_data_info_dict=extra_data_info + ) + requested_neighbor_lists = get_requested_neighbor_lists(model) + collate_fn_train = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ + rotational_augmenter.apply_random_augmentations, + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + collate_fn_val = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ # no augmentation for validation + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + + # Create dataloader for the training datasets: + if self.hypers["num_workers"] is None: + num_workers = get_num_workers() + logging.info( + "Number of workers for data-loading not provided and chosen " + f"automatically. Using {num_workers} workers." + ) + else: + num_workers = self.hypers["num_workers"] + validate_num_workers(num_workers) + + train_dataloaders = [] + for train_dataset, train_sampler in zip( + train_datasets, train_samplers, strict=True + ): + if len(train_dataset) < self.hypers["batch_size"]: + raise ValueError( + f"A training dataset has fewer samples " + f"({len(train_dataset)}) than the batch size " + f"({self.hypers['batch_size']}). " + "Please reduce the batch size." 
+ ) + train_dataloaders.append( + DataLoader( + dataset=train_dataset, + batch_size=self.hypers["batch_size"], + sampler=train_sampler, + shuffle=( + # the sampler takes care of this (if present) + train_sampler is None + ), + drop_last=( + # the sampler takes care of this (if present) + train_sampler is None + ), + collate_fn=collate_fn_train, + num_workers=num_workers, + ) + ) + train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True) + + # Create dataloader for the validation datasets: + val_dataloaders = [] + for val_dataset, val_sampler in zip(val_datasets, val_samplers, strict=True): + val_dataloaders.append( + DataLoader( + dataset=val_dataset, + batch_size=self.hypers["batch_size"], + sampler=val_sampler, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_val, + num_workers=num_workers, + ) + ) + val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False) + + if is_distributed: + model = DistributedDataParallel(model, device_ids=[device]) + + outputs_list = [] + for target_name, target_info in train_targets.items(): + outputs_list.append(target_name) + for gradient_name in target_info.gradients: + outputs_list.append(f"{target_name}_{gradient_name}_gradients") + + # Create a loss function: + loss_hypers = cast(Dict[str, LossSpecification], self.hypers["loss"]) # mypy + loss_fn = LossAggregator(targets=train_targets, config=loss_hypers) + logging.info("Using the following loss functions:") + for name, info in loss_fn.metadata.items(): + logging.info(f"{name}:") + main = {k: v for k, v in info.items() if k != "gradients"} + logging.info(main) + if "gradients" not in info or len(info["gradients"]) == 0: + continue + logging.info("With gradients:") + for grad, ginfo in info["gradients"].items(): + logging.info(f"\t{name}::{grad}: {ginfo}") + + if self.hypers["weight_decay"] is not None: + optimizer = torch.optim.AdamW( + model.parameters(), + lr=self.hypers["learning_rate"], + weight_decay=self.hypers["weight_decay"], + ) + else: + optimizer = torch.optim.Adam( + model.parameters(), lr=self.hypers["learning_rate"] + ) + + if self.optimizer_state_dict is not None and not is_finetune: + # try to load the optimizer state dict, but this is only possible + # if there are no new targets in the model (new parameters) + if not (model.module if is_distributed else model).has_new_targets: + optimizer.load_state_dict(self.optimizer_state_dict) + + # Create a learning rate scheduler + lr_scheduler = get_scheduler(optimizer, self.hypers, len(train_dataloader)) + + if self.scheduler_state_dict is not None and not is_finetune: + # same as the optimizer, try to load the scheduler state dict + if not (model.module if is_distributed else model).has_new_targets: + lr_scheduler.load_state_dict(self.scheduler_state_dict) + + per_structure_targets = self.hypers["per_structure_targets"] + + # Log the initial learning rate: + logging.info(f"Base learning rate: {self.hypers['learning_rate']}") + + start_epoch = -1 if self.epoch is None else self.epoch + epoch = start_epoch + + # Save the untrained model checkpoint: + if self.best_metric is None: + self.best_metric = float("inf") + logging.info("Saving untrained model checkpoint") + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = 
lr_scheduler.state_dict() + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_untrained.ckpt", + ) + + # Train the model: + logging.info("Starting training") + start_epoch = start_epoch + 1 + + for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]): + if is_distributed: + for train_sampler in train_samplers: + train_sampler.set_epoch(epoch) + train_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + val_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + if self.hypers["log_mae"]: + train_mae_calculator = MAEAccumulator( + self.hypers["log_separate_blocks"] + ) + val_mae_calculator = MAEAccumulator(self.hypers["log_separate_blocks"]) + + train_loss = 0.0 + for batch in train_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + optimizer.zero_grad() + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=True, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + train_loss_batch = loss_fn(predictions, targets, extra_data) + + if is_distributed: + # make sure all parameters contribute to the gradient calculation + # to make torch DDP happy + for param in model.parameters(): + train_loss_batch += 0.0 * param.sum() + + train_loss_batch.backward() + torch.nn.utils.clip_grad_norm_( + model.parameters(), self.hypers["grad_clip_norm"] + ) + optimizer.step() + lr_scheduler.step() + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(train_loss_batch) + train_loss += train_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + train_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + train_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_train_info = train_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_train_info.update( + train_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + val_loss = 0.0 + for batch in val_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=False, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + val_loss_batch = loss_fn(predictions, targets, extra_data) + + if 
is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(val_loss_batch) + val_loss += val_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + val_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + val_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_val_info = val_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_val_info.update( + val_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + # Now we log the information: + finalized_train_info = { + "loss": train_loss, + **finalized_train_info, + } + finalized_val_info = { + "loss": val_loss, + **finalized_val_info, + } + + if epoch == start_epoch: + metric_logger = MetricLogger( + log_obj=ROOT_LOGGER, + dataset_info=( + model.module if is_distributed else model + ).dataset_info, + initial_metrics=[finalized_train_info, finalized_val_info], + names=["training", "validation"], + ) + if epoch % self.hypers["log_interval"] == 0: + metric_logger.log( + metrics=[finalized_train_info, finalized_val_info], + epoch=epoch, + rank=rank, + learning_rate=optimizer.param_groups[0]["lr"], + ) + + val_metric = get_selected_metric( + finalized_val_info, self.hypers["best_model_metric"] + ) + if val_metric < self.best_metric: + self.best_metric = val_metric + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + if epoch % self.hypers["checkpoint_interval"] == 0: + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + self.epoch = epoch + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_{epoch}.ckpt", + ) + + # prepare for the checkpoint that will be saved outside the function + self.epoch = epoch + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + + if is_distributed: + torch.distributed.destroy_process_group() + + def save_checkpoint(self, model: ModelInterface, path: Union[str, Path]) -> None: + checkpoint = model.get_checkpoint() + if self.best_model_state_dict is not None: + self.best_model_state_dict["finetune_config"] = model.finetune_config + checkpoint.update( + { + "trainer_ckpt_version": self.__checkpoint_version__, + "train_hypers": self.hypers, + "epoch": self.epoch, + "optimizer_state_dict": self.optimizer_state_dict, + "scheduler_state_dict": self.scheduler_state_dict, + "best_epoch": self.best_epoch, + "best_metric": self.best_metric, + "best_model_state_dict": self.best_model_state_dict, + "best_optimizer_state_dict": self.best_optimizer_state_dict, + } + ) + torch.save( + checkpoint, + check_file_extension(path, ".ckpt"), + ) + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + hypers: TrainerHypers, + context: Literal["restart", "finetune"], + ) -> "Trainer": + trainer = cls(hypers) + trainer.optimizer_state_dict = 
checkpoint["optimizer_state_dict"] + trainer.scheduler_state_dict = checkpoint["scheduler_state_dict"] + trainer.epoch = checkpoint["epoch"] + trainer.best_epoch = checkpoint["best_epoch"] + trainer.best_metric = checkpoint["best_metric"] + trainer.best_model_state_dict = checkpoint["best_model_state_dict"] + trainer.best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"] + + return trainer + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["trainer_ckpt_version"] == v: + update = getattr(checkpoints, f"trainer_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["trainer_ckpt_version"] = v + 1 + + if checkpoint["trainer_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using " + f"trainer version {checkpoint['trainer_ckpt_version']}, while the " + f"current trainer version is {cls.__checkpoint_version__}." + ) + return checkpoint diff --git a/src/metatrain/pet_12/__init__.py b/src/metatrain/pet_12/__init__.py new file mode 100644 index 0000000000..01db7d8aa9 --- /dev/null +++ b/src/metatrain/pet_12/__init__.py @@ -0,0 +1,21 @@ +from .model import PET +from .trainer import Trainer + + +__model__ = PET +__trainer__ = Trainer +__capabilities__ = { + "supported_devices": __model__.__supported_devices__, + "supported_dtypes": __model__.__supported_dtypes__, +} + +__authors__ = [ + ("Sergey Pozdnyakov ", "@spozdn"), + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] + +__maintainers__ = [ + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] diff --git a/src/metatrain/pet_12/checkpoints.py b/src/metatrain/pet_12/checkpoints.py new file mode 100644 index 0000000000..020b1e6590 --- /dev/null +++ b/src/metatrain/pet_12/checkpoints.py @@ -0,0 +1,422 @@ +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + + +########################### +# MODEL ################### +########################### + + +def model_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "additive_models.0.model.type_to_index" not in state_dict: + state_dict["additive_models.0.model.type_to_index"] = state_dict.pop( + "additive_models.0.type_to_index" + ) + + +def model_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "train_hypers" in checkpoint: + finetune_config = checkpoint["train_hypers"].get("finetune", {}) + else: + finetune_config = {} + state_dict["finetune_config"] = finetune_config + + +def model_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["epoch"] = checkpoint.get("epoch") + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + if checkpoint["best_model_state_dict"] is not None: + checkpoint["best_model_state_dict"] = checkpoint.get("best_model_state_dict") + + +def model_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 checkpoint to v5. + + :param checkpoint: The checkpoint to update. 
+ """ + checkpoint["model_data"]["dataset_info"]._atomic_types = list( + checkpoint["model_data"]["dataset_info"]._atomic_types + ) + + +def model_update_v5_v6(checkpoint: dict) -> None: + """ + Update a v5 checkpoint to v6. + + :param checkpoint: The checkpoint to update. + """ + if not checkpoint["best_model_state_dict"]: + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + +def model_update_v6_v7(checkpoint: dict) -> None: + """ + Update model checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to be updated. + """ + # this update consists in changes in the scaler + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if ( + "scaler.scales" not in state_dict + and "scaler.dummy_buffer" in state_dict + and "scaler.model.type_to_index" in state_dict + ): + continue # already updated + old_scales_tensor = state_dict.pop("scaler.scales") + old_output_name_to_output_index = {} + for target_index, target_name in enumerate( + checkpoint["model_data"]["dataset_info"].targets.keys() + ): + old_output_name_to_output_index[target_name] = target_index + state_dict["scaler.dummy_buffer"] = torch.tensor( + [0.0], dtype=old_scales_tensor.dtype + ) + state_dict["scaler.model.type_to_index"] = state_dict[ + "additive_models.0.model.type_to_index" + ] + for target_name, target_info in checkpoint["model_data"][ + "dataset_info" + ].targets.items(): + layout = target_info.layout + if layout.sample_names == ["system"]: + samples = Labels(["atomic_type"], torch.tensor([[-1]])) + + elif layout.sample_names == ["system", "atom"]: + samples = Labels( + ["atomic_type"], + torch.arange( + len(checkpoint["model_data"]["dataset_info"].atomic_types) + ).reshape(-1, 1), + ) + else: + raise ValueError( # will never happen + "Unknown sample kind. Please contact the developers." + ) + scales_tensormap = TensorMap( + keys=layout.keys, + blocks=[ + TensorBlock( + values=torch.full( # important when scale_targets=False + (len(samples), len(block.properties)), + old_scales_tensor[ + old_output_name_to_output_index[target_name] + ], + dtype=torch.float64, + ), + samples=samples, + components=[], + properties=block.properties, + ) + for block in layout.blocks() + ], + ) + state_dict[f"scaler.{target_name}_scaler_buffer"] = mts.save_buffer( + mts.make_contiguous(scales_tensormap) + ) + + +def model_update_v7_v8(checkpoint: dict) -> None: + """ + Update a v7 checkpoint to v8. + + :param checkpoint: The checkpoint to update. + """ + if "node_embedding.weight" in checkpoint["model_state_dict"]: + ############################################## + # **Updating the large-scale PET checkpoints** + ############################################## + + checkpoint["model_data"]["model_hypers"]["normalization"] = "RMSNorm" + checkpoint["model_data"]["model_hypers"]["activation"] = "SwiGLU" + checkpoint["model_data"]["model_hypers"]["d_node"] = ( + 4 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PreLN" + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "feedforward" + checkpoint["model_data"]["model_hypers"]["d_feedforward"] = ( + 2 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + if (state_dict := checkpoint.get("model_state_dict")) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + if "embedding." 
in k and "node" not in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if "node_embedding." in k: + k = k.replace("node_embedding.", "node_embedders.0.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + if "combination_rmsnorms" in k: + k = k.replace("combination_rmsnorms", "combination_norms") + new_state_dict[k] = v + checkpoint["model_state_dict"] = new_state_dict + + # for the large-scale checkpoints, the model for evaluation is always + # taken to be the last + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + else: + ############################################### + # **Updating the standard old PET checkpoints** + ############################################### + + # Adding the option for choosing the normalization type + if "normalization" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["normalization"] = "LayerNorm" + # Adding the option for choosing the activation function + if "activation" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["activation"] = "SiLU" + # Setting the node features dimension to be the same as d_pet if not specified + if "d_node" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["d_node"] = checkpoint[ + "model_data" + ]["model_hypers"]["d_pet"] + # Setting the default transformer type to PostLN if not specified + if "transformer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PostLN" + if "featurizer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "residual" + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + # Replacing the nn.Sequential MLP with a custom FeedForward module + if ".mlp.0" in k: + k = k.replace(".mlp.0", ".mlp.w_in") + if ".mlp.3" in k: + k = k.replace(".mlp.3", ".mlp.w_out") + # Moving the node embedder to a top-level PET model attribute + if "embedding." in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".node_embedder." in k: + key_content = k.split(".") + k = ".".join( + ["node_embedders", key_content[1], key_content[-1]] + ) + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + new_state_dict[k] = v + checkpoint[key] = new_state_dict + + +def model_update_v8_v9(checkpoint: dict) -> None: + """ + Update a v8 checkpoint to v9. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "finetune_config" in state_dict: + if "inherit_heads" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["inherit_heads"] = {} + if "method" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["method"] = "full" + + +def model_update_v9_v10(checkpoint: dict) -> None: + """ + Update a v9 checkpoint to v10. + + :param checkpoint: The checkpoint to update. 
+ """ + # Adding the num_neighbors_adaptive hyperparameter if not present + if "num_neighbors_adaptive" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["num_neighbors_adaptive"] = None + if "cutoff_function" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["cutoff_function"] = "Cosine" + + +def model_update_v10_v11(checkpoint: dict) -> None: + """ + Update a v10 checkpoint to v11. + + :param checkpoint: The checkpoint to update. + """ + # Adding the attention_temperature hyperparameter if not present + if "attention_temperature" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["attention_temperature"] = 1.0 + + +########################### +# TRAINER ################# +########################### + + +def trainer_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 Trainer checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["scheduler_factor"] = checkpoint["train_hypers"].get( + "scheduler_factor", 0.5 + ) + + +def trainer_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 Trainer checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + +def trainer_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 Trainer checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + new_loss_hypers[target_name] = { + "type": old_loss_hypers["type"], + "weight": old_loss_hypers["weights"].get(target_name, 1.0), + "reduction": old_loss_hypers["reduction"], + "sliding_factor": old_loss_hypers.get("sliding_factor", None), + } + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 Trainer checkpoint to v5. + + :param checkpoint: The checkpoint to update. + """ + raise ValueError( + "In order to use this checkpoint, you need metatrain 2025.10 or earlier. " + "You can install it with `pip install metatrain==2025.10`." + ) + + +def trainer_update_v5_v6(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 5 to version 6. + + :param checkpoint: The checkpoint to update. + """ + # num_workers=0 means that the main process will do the data loading, which is + # equivalent to not setting it (this was the behavior before v6) + checkpoint["train_hypers"]["num_workers"] = 0 + + +def trainer_update_v6_v7(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["fixed_scaling_weights"] = {} + + +def trainer_update_v7_v8(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. 
+ """ + # remove all entries in the loss `sliding_factor` + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + # retain everything except sliding_factor for each target + new_loss_hypers[target_name] = { + k: v + for k, v in old_loss_hypers[target_name].items() + if k != "sliding_factor" + } + + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v8_v9(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. + """ + # Adding the empty finetune config if not present + if "finetune" not in checkpoint["train_hypers"]: + checkpoint["train_hypers"]["finetune"] = { + "read_from": None, + "method": "full", + "config": {}, + "inherit_heads": {}, + } + + +def trainer_update_v9_v10(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 9 to version 10. + + :param checkpoint: The checkpoint to update. + """ + # Ensuring that the finetune read_from is None if not specified + checkpoint["train_hypers"]["remove_composition_contribution"] = True + + +def trainer_update_v10_v11(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 10 to version 11. + + :param checkpoint: The checkpoint to update. + """ + # - Remove the ``remove_composition_contribution`` hyper. + # - Rename ``fixed_composition_weights`` to ``atomic_baseline``. + # - If ``remove_composition_contribution`` is False, set all atomic baselines + # to 0.0 for all targets. + use_atomic_baseline = checkpoint["train_hypers"].pop( + "remove_composition_contribution" + ) + atomic_baseline = checkpoint["train_hypers"].pop("fixed_composition_weights") + + if not use_atomic_baseline: + # Just set + dataset_info = checkpoint["model_data"]["dataset_info"] + atomic_baseline = {target_name: 0.0 for target_name in dataset_info.targets} + + checkpoint["train_hypers"]["atomic_baseline"] = atomic_baseline + + +def trainer_update_v11_v12(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 11 to version 12. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["batch_atom_bounds"] = [None, None] diff --git a/src/metatrain/pet_12/documentation.py b/src/metatrain/pet_12/documentation.py new file mode 100644 index 0000000000..052e523fe4 --- /dev/null +++ b/src/metatrain/pet_12/documentation.py @@ -0,0 +1,251 @@ +""" +PET +=== + +PET is a cleaner, more user-friendly reimplementation of the original +PET model :footcite:p:`pozdnyakov_smooth_2023`. It is designed for better +modularity and maintainability, while preseving compatibility with the original +PET implementation in ``metatrain``. It also adds new features like long-range +features, better fine-tuning implementation, a possibility to train on +arbitrarty targets, and a faster inference due to the ``fast attention``. + +{{SECTION_INSTALLATION}} + +{{SECTION_DEFAULT_HYPERS}} + +Tuning hyperparameters +---------------------- + +The default hyperparameters above will work well in most cases, but they +may not be optimal for your specific dataset. There is good number of +parameters to tune, both for the +:ref:`model ` and the +:ref:`trainer `. Since seeing them +for the first time might be overwhelming, here we provide a **list of the +parameters that are in general the most important** (in decreasing order +of importance): + +.. container:: mtt-hypers-remove-classname + + .. 
autoattribute:: {{model_hypers_path}}.cutoff + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_neighbors_adaptive + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.learning_rate + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.batch_size + :no-index: + + .. autoattribute:: {{model_hypers_path}}.d_pet + :no-index: + + .. autoattribute:: {{model_hypers_path}}.d_node + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_gnn_layers + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_attention_layers + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.loss + :no-index: + + .. autoattribute:: {{model_hypers_path}}.long_range + :no-index: +""" + +from typing import Literal, Optional + +from typing_extensions import TypedDict + +from metatrain.pet_12.modules.finetuning import FinetuneHypers, NoFinetuneHypers +from metatrain.utils.additive import FixedCompositionWeights +from metatrain.utils.hypers import init_with_defaults +from metatrain.utils.long_range import LongRangeHypers +from metatrain.utils.loss import LossSpecification +from metatrain.utils.scaler import FixedScalerWeights + + +class ModelHypers(TypedDict): + """Hyperparameters for the PET model.""" + + cutoff: float = 4.5 + """Cutoff radius for neighbor search. + + This should be set to a value after which most of the interactions + between atoms is expected to be negligible. A lower cutoff will lead + to faster models. + """ + num_neighbors_adaptive: Optional[int] = None + """Target number of neighbors for the adaptive cutoff scheme. + + This parameter activates the adaptive cutoff functionality. + Each atomic environments has a different cutoff, that is chosen + such that the number of neighbors is approximately equal to this + value. This can be useful to have a more uniform number of neighbors + per atom, especially in sparse systems. Setting it to None disables + this feature and uses all neighbors within the fixed cutoff radius. + """ + cutoff_function: Literal["Cosine", "Bump"] = "Bump" + """Type of the smoothing function at the cutoff""" + cutoff_width: float = 0.5 + """Width of the smoothing function at the cutoff""" + d_pet: int = 128 + """Dimension of the edge features. + + This hyperparameters controls width of the neural network. In general, + increasing it might lead to better accuracy, especially on larger datasets, at the + cost of increased training and evaluation time. + """ + d_head: int = 128 + """Dimension of the attention heads.""" + d_node: int = 256 + """Dimension of the node features. + + Increasing this hyperparameter might lead to better accuracy, + with a relatively small increase in inference time. + """ + d_feedforward: int = 256 + """Dimension of the feedforward network in the attention layer.""" + num_heads: int = 8 + """Attention heads per attention layer.""" + num_attention_layers: int = 2 + """The number of attention layers in each layer of the graph + neural network. Depending on the dataset, increasing this hyperparameter might + lead to better accuracy, at the cost of increased training and evaluation time. + """ + num_gnn_layers: int = 2 + """The number of graph neural network layers. + + In general, decreasing this hyperparameter to 1 will lead to much faster models, + at the expense of accuracy. Increasing it may or may not lead to better accuracy, + depending on the dataset, at the cost of increased training and evaluation time. 
+ """ + normalization: Literal["RMSNorm", "LayerNorm"] = "RMSNorm" + """Layer normalization type.""" + activation: Literal["SiLU", "SwiGLU"] = "SwiGLU" + """Activation function.""" + attention_temperature: float = 1.0 + """The temperature scaling factor for attention scores.""" + transformer_type: Literal["PreLN", "PostLN"] = "PreLN" + """The order in which the layer normalization and attention + are applied in a transformer block. Available options are ``PreLN`` + (normalization before attention) and ``PostLN`` (normalization after attention).""" + featurizer_type: Literal["residual", "feedforward"] = "feedforward" + """Implementation of the featurizer of the model to use. Available + options are ``residual`` (the original featurizer from the PET paper, that uses + residual connections at each GNN layer for readout) and ``feedforward`` (a modern + version that uses the last representation after all GNN iterations for readout). + Additionally, the feedforward version uses bidirectional features flow during the + message passing iterations, that favors features flowing from atom ``i`` to atom + ``j`` to be not equal to the features flowing from atom ``j`` to atom ``i``.""" + zbl: bool = False + """Use ZBL potential for short-range repulsion""" + long_range: LongRangeHypers = init_with_defaults(LongRangeHypers) + """Long-range Coulomb interactions parameters.""" + + +class TrainerHypers(TypedDict): + """Hyperparameters for training PET models.""" + + distributed: bool = False + """Whether to use distributed training""" + distributed_port: int = 39591 + """Port for distributed communication among processes""" + batch_size: int = 16 + """The number of samples to use in each batch of training. This + hyperparameter controls the tradeoff between training speed and memory usage. In + general, larger batch sizes will lead to faster training, but might require more + memory.""" + num_epochs: int = 1000 + """Number of epochs.""" + warmup_fraction: float = 0.01 + """Fraction of training steps used for learning rate warmup.""" + learning_rate: float = 1e-4 + """Learning rate.""" + weight_decay: Optional[float] = None + + log_interval: int = 1 + """Interval to log metrics.""" + checkpoint_interval: int = 100 + """Interval to save checkpoints.""" + atomic_baseline: FixedCompositionWeights = {} + """The baselines for each target. + + By default, ``metatrain`` will fit a linear model (:class:`CompositionModel + `) to compute the + least squares baseline for each atomic species for each target. + + However, this hyperparameter allows you to provide your own baselines. + The value of the hyperparameter should be a dictionary where the keys are the + target names, and the values are either (1) a single baseline to be used for + all atomic types, or (2) a dictionary mapping atomic types to their baselines. + For example: + + - ``atomic_baseline: {"energy": {1: -0.5, 6: -10.0}}`` will fix the energy + baseline for hydrogen (Z=1) to -0.5 and for carbon (Z=6) to -10.0, while + fitting the baselines for the energy of all other atomic types, as well + as fitting the baselines for all other targets. + - ``atomic_baseline: {"energy": -5.0}`` will fix the energy baseline for + all atomic types to -5.0. + - ``atomic_baseline: {"mtt:dos": 0.0}`` sets the baseline for the "mtt:dos" + target to 0.0, effectively disabling the atomic baseline for that target. 
+ + This atomic baseline is substracted from the targets during training, which + avoids the main model needing to learn atomic contributions, and likely makes + training easier. When the model is used in evaluation mode, the atomic baseline + is added on top of the model predictions automatically. + + .. note:: + + This atomic baseline is a per-atom contribution. Therefore, if the property + you are predicting is a sum over all atoms (e.g., total energy), the + contribution of the atomic baseline to the total property will be the + atomic baseline multiplied by the number of atoms of that type in the + structure. + """ + scale_targets: bool = True + """Normalize targets to unit std during training.""" + fixed_scaling_weights: FixedScalerWeights = {} + """Weights for target scaling. + + This is passed to the ``fixed_weights`` argument of + :meth:`Scaler.train_model `, + see its documentation to understand exactly what to pass here. + """ + per_structure_targets: list[str] = [] + """Targets to calculate per-structure losses.""" + num_workers: Optional[int] = None + """Number of workers for data loading. If not provided, it is set + automatically.""" + log_mae: bool = True + """Log MAE alongside RMSE""" + log_separate_blocks: bool = False + """Log per-block error.""" + best_model_metric: Literal["rmse_prod", "mae_prod", "loss"] = "mae_prod" + """Metric used to select best checkpoint (e.g., ``rmse_prod``)""" + grad_clip_norm: float = 1.0 + """Maximum gradient norm value.""" + loss: str | dict[str, LossSpecification | str] = "mse" + """This section describes the loss function to be used. See the + :ref:`loss-functions` for more details.""" + batch_atom_bounds: list[Optional[int]] = [None, None] + """Bounds for the number of atoms per batch as [min, max]. Batches with atom + counts outside these bounds will be skipped during training. Use ``None`` for + either value to disable that bound. This is useful for preventing out-of-memory + errors and ensuring consistent computational load. Default: ``[None, None]``.""" + + finetune: NoFinetuneHypers | FinetuneHypers = { + "read_from": None, + "method": "full", + "config": {}, + "inherit_heads": {}, + } + """Parameters for fine-tuning trained PET models. + + See :ref:`label_fine_tuning_concept` for more details. + """ diff --git a/src/metatrain/pet_12/model.py b/src/metatrain/pet_12/model.py new file mode 100644 index 0000000000..cc2104a268 --- /dev/null +++ b/src/metatrain/pet_12/model.py @@ -0,0 +1,1646 @@ +import logging +import typing +import warnings +from math import prod +from typing import Any, Dict, List, Literal, Optional, Tuple + +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.operations._add import _add_block_block +from metatomic.torch import ( + AtomisticModel, + ModelCapabilities, + ModelMetadata, + ModelOutput, + NeighborListOptions, + System, +) +from torch.utils.hooks import RemovableHandle + +from metatrain.utils.abc import ModelInterface +from metatrain.utils.additive import ZBL, CompositionModel +from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.dtype import dtype_to_str +from metatrain.utils.long_range import DummyLongRangeFeaturizer, LongRangeFeaturizer +from metatrain.utils.metadata import merge_metadata +from metatrain.utils.scaler import Scaler +from metatrain.utils.sum_over_atoms import sum_over_atoms + +from . 
import checkpoints +from .documentation import ModelHypers +from .modules.finetuning import apply_finetuning_strategy +from .modules.structures import get_pair_sample_labels, systems_to_batch +from .modules.transformer import CartesianTransformer + + +AVAILABLE_FEATURIZERS = typing.get_args(ModelHypers.__annotations__["featurizer_type"]) + + +class PET(ModelInterface[ModelHypers]): + """ + Metatrain-native implementation of the PET architecture. + + Originally proposed in work (https://arxiv.org/abs/2305.19302v3), + and published in the `pet` package (https://github.com/spozdn/pet). + + :param hypers: Hyperparameters for the PET model. See the documentation for details. + :param dataset_info: Information about the dataset, including atomic types and + targets. + """ + + __checkpoint_version__ = 11 + __supported_devices__ = ["cuda", "cpu"] + __supported_dtypes__ = [torch.float32, torch.float64] + __default_metadata__ = ModelMetadata( + references={"architecture": ["https://arxiv.org/abs/2305.19302v3"]} + ) + component_labels: Dict[str, List[List[Labels]]] + NUM_FEATURE_TYPES: int = 2 # node + edge features + + def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None: + super().__init__(hypers, dataset_info, self.__default_metadata__) + + # Cache frequently accessed hyperparameters + self.cutoff = float(self.hypers["cutoff"]) + self.cutoff_function = self.hypers["cutoff_function"] + self.cutoff_width = float(self.hypers["cutoff_width"]) + self.num_neighbors_adaptive = ( + float(self.hypers["num_neighbors_adaptive"]) + if self.hypers["num_neighbors_adaptive"] is not None + else None + ) + self.d_pet = self.hypers["d_pet"] + self.d_node = self.hypers["d_node"] + self.d_head = self.hypers["d_head"] + self.d_feedforward = self.hypers["d_feedforward"] + self.num_heads = self.hypers["num_heads"] + self.num_gnn_layers = self.hypers["num_gnn_layers"] + self.num_attention_layers = self.hypers["num_attention_layers"] + self.normalization = self.hypers["normalization"] + self.activation = self.hypers["activation"] + self.attention_temperature = self.hypers["attention_temperature"] + self.transformer_type = self.hypers["transformer_type"] + self.featurizer_type = self.hypers["featurizer_type"] + + self.atomic_types = dataset_info.atomic_types + self.requested_nl = NeighborListOptions( + cutoff=self.cutoff, + full_list=True, + strict=True, + ) + num_atomic_species = len(self.atomic_types) + self.gnn_layers = torch.nn.ModuleList( + [ + CartesianTransformer( + self.cutoff, + self.cutoff_width, + self.d_pet, + self.num_heads, + self.d_node, + self.d_feedforward, + self.num_attention_layers, + self.normalization, + self.activation, + self.attention_temperature, + self.transformer_type, + num_atomic_species, + layer_index == 0, # is first layer + ) + for layer_index in range(self.num_gnn_layers) + ] + ) + if self.featurizer_type not in AVAILABLE_FEATURIZERS: + raise ValueError( + f"Unknown featurizer type: {self.featurizer_type}. 
" + f"Available options are: {AVAILABLE_FEATURIZERS}" + ) + if self.featurizer_type == "feedforward": + self.num_readout_layers = 1 + self.combination_norms = torch.nn.ModuleList( + [torch.nn.LayerNorm(2 * self.d_pet) for _ in range(self.num_gnn_layers)] + ) + self.combination_mlps = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(2 * self.d_pet, 2 * self.d_pet), + torch.nn.SiLU(), + torch.nn.Linear(2 * self.d_pet, self.d_pet), + ) + for _ in range(self.num_gnn_layers) + ] + ) + else: + self.num_readout_layers = self.num_gnn_layers + self.combination_norms = torch.nn.ModuleList() + self.combination_mlps = torch.nn.ModuleList() + + self.node_embedders = torch.nn.ModuleList( + [ + torch.nn.Embedding(num_atomic_species, self.d_node) + for _ in range(self.num_readout_layers) + ] + ) + self.edge_embedder = torch.nn.Embedding(num_atomic_species, self.d_pet) + + self.node_heads = torch.nn.ModuleDict() + self.edge_heads = torch.nn.ModuleDict() + self.node_last_layers = torch.nn.ModuleDict() + self.edge_last_layers = torch.nn.ModuleDict() + self.last_layer_feature_size = ( + self.num_readout_layers * self.d_head * self.NUM_FEATURE_TYPES + ) # for LLPR + + # the model is always capable of outputting the internal features + self.outputs = { + "features": ModelOutput(per_atom=True, description="internal features") + } + + self.output_shapes: Dict[str, Dict[str, List[int]]] = {} + self.key_labels: Dict[str, Labels] = {} + self.property_labels: Dict[str, List[Labels]] = {} + self.component_labels: Dict[str, List[List[Labels]]] = {} + self.target_names: List[str] = [] + self.last_layer_parameter_names: Dict[str, List[str]] = {} # for LLPR + for target_name, target_info in dataset_info.targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target_info) + + self.register_buffer( + "species_to_species_index", + torch.full((max(self.atomic_types) + 1,), -1), + ) + for i, species in enumerate(self.atomic_types): + self.species_to_species_index[species] = i + + # long-range module + if self.hypers["long_range"]["enable"]: + self.long_range = True + if not self.hypers["long_range"]["use_ewald"]: + warnings.warn( + "Training PET with the LongRangeFeaturizer initialized " + "with `use_ewald=False` causes instabilities during training. " + "The `use_ewald` variable will be force-switched to `True`. 
" + "during training.", + UserWarning, + stacklevel=2, + ) + self.long_range_featurizer = LongRangeFeaturizer( + hypers=self.hypers["long_range"], + feature_dim=self.d_node, + neighbor_list_options=self.requested_nl, + ) + else: + self.long_range = False + self.long_range_featurizer = DummyLongRangeFeaturizer() # for torchscript + + # additive models: these are handled by the trainer at training + # time, and they are added to the output at evaluation time + composition_model = CompositionModel( + hypers={}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + additive_models = [composition_model] + + # Adds the ZBL repulsion model if requested + if self.hypers["zbl"]: + additive_models.append( + ZBL( + {}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if ZBL.is_valid_target(target_name, target_info) + }, + ), + ) + ) + self.additive_models = torch.nn.ModuleList(additive_models) + + # scaler: this is also handled by the trainer at training time + self.scaler = Scaler(hypers={}, dataset_info=dataset_info) + + self.single_label = Labels.single() + + self.finetune_config: Dict[str, Any] = {} + + def supported_outputs(self) -> Dict[str, ModelOutput]: + return self.outputs + + def restart(self, dataset_info: DatasetInfo) -> "PET": + # merge old and new dataset info + merged_info = self.dataset_info.union(dataset_info) + new_atomic_types = [ + at for at in merged_info.atomic_types if at not in self.atomic_types + ] + new_targets = { + key: value + for key, value in merged_info.targets.items() + if key not in self.dataset_info.targets + } + self.has_new_targets = len(new_targets) > 0 + + if len(new_atomic_types) > 0: + raise ValueError( + f"New atomic types found in the dataset: {new_atomic_types}. " + "The PET model does not support adding new atomic types." + ) + + # register new outputs as new last layers + for target_name, target in new_targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target) + + self.dataset_info = merged_info + + # restart the composition and scaler models + self.additive_models[0] = self.additive_models[0].restart( + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + self.scaler = self.scaler.restart(dataset_info) + + return self + + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return [self.requested_nl] + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + """ + Forward pass of the PET model. + + The forward pass processes atomic systems through multiple stages to produce + predictions for the requested outputs. The computation follows a graph neural + network architecture with attention-based message passing. 
+ + **Stage 0: Input Preparation** + + The input systems are first converted into a batched representation containing: + + - `element_indices_nodes` [n_atoms]: Atomic species of the central atoms + - `element_indices_neighbors` [n_atoms, max_num_neighbors]: Atomic species of + neighboring atoms + - `edge_vectors` [n_atoms, max_num_neighbors, 3]: Cartesian edge vectors + between central atoms and their neighbors + - `padding_mask` [n_atoms, max_num_neighbors]: Mask indicating real vs padded + neighbors + - `reverse_neighbor_index` [n_atoms * max_num_neighbors]: Index of the ji edge + for each ij edge, once the edges are flattened into an array whose first + dimension is n_atoms * max_num_neighbors + - `system_indices` [n_atoms]: System index for each central atom + - `sample_labels` [n_atoms, 2]: Metatensor Labels containing indices of each + atom in each system + + **Stage 1: Feature Computation via GNN Layers** + + Node and edge representations are computed by iterating through the GNN layers + following one of two featurization strategies: + + - **Feedforward featurization**: Features are propagated through all + `num_gnn_layers` GNN layers sequentially, using only the final layer outputs + for readout. At each layer, forward and reversed edge messages are combined + using combination MLPs to enable bidirectional information flow. + + - **Residual featurization**: Intermediate node and edge features from each + GNN layer are saved and used during readout. Edge messages between layers are + averaged to maintain information from all hops. + + During this stage, the model: + + - Embeds atomic species into learned node and edge representations + - Applies Cartesian transformer layers to update features via attention + - Uses reversed neighbor lists to enable bidirectional message passing, where + the new input message from atom `j` to atom `i` in GNN layer N+1 is the + reversed message from atom `i` to atom `j` in GNN layer N + - Applies cutoff functions to weight interactions by distance + + If the long-range module is enabled, electrostatic features computed via Ewald + summation (during training) or Particle-Particle Particle Mesh Ewald (P3M) + (during evaluation) are added to the node features from each GNN layer. + + **Stage 2: Intermediate Feature Output (Optional)** + + If "features" is requested in the outputs, node and edge features from all + layers are concatenated to produce intermediate representations. Edge features + are summed over neighbors with cutoff weighting to obtain per-node + contributions. This output can be used for transfer learning or analysis. + + **Stage 3: Last Layer Feature Computation** + + For each requested output, output-specific heads (shallow MLPs with two linear + layers and SiLU activations) are applied to both node and edge features from + each GNN layer. This produces last layer features that are specialized for each + prediction target. These features can be optionally returned as auxiliary + outputs (e.g., "mtt::aux::energy_last_layer_features") for analysis or + transfer learning. 
+ + **Stage 4: Atomic Predictions** + + Final linear layers are applied to the last layer features to produce per-atom + predictions for each requested output: + + - Node and edge last layer features are processed through separate linear + layers for each output block + - Contributions from all GNN layers are summed + - Edge contributions are summed over neighbors with cutoff weighting + - For rank-2 Cartesian tensors (e.g., stress), predictions are symmetrized and + normalized by cell volume + - Multiple tensor blocks per output are handled independently + + **Post-processing (Evaluation Only)** + + During evaluation (not training), the following transformations are applied: + + 1. **Scaling**: Predictions are scaled using learned or configured scale + factors + 2. **Additive contributions**: Composition model and optional ZBL repulsion + contributions are added to the predictions + + :param systems: List of `metatomic.torch.System` objects to process. Each + system should contain atomic positions, species, and cell information, with + neighbor lists computed according to `requested_neighbor_lists()`. + :param outputs: Dictionary of requested outputs in the format + {output_name: ModelOutput(...)}. The model supports: + + - Target properties (energy, forces, stress, etc.) + - "features": intermediate representations from Stage 2 + - Auxiliary last layer features (e.g., + "mtt::aux::energy_last_layer_features") + + :param selected_atoms: Optional `metatensor.torch.Labels` object specifying a + subset of atoms for which to compute outputs. If `None`, all atoms are + included. This is useful for computing properties for specific atomic + environments. + :return: Dictionary of `metatensor.torch.TensorMap` objects containing the + requested outputs. Each TensorMap contains per-atom or per-structure + predictions (depending on the ModelOutput configuration) with appropriate + metatensor metadata (samples, components, properties). 
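        A minimal usage sketch (assuming ``model`` is a trained ``PET`` instance and
        ``systems`` is a list of ``metatomic.torch.System`` objects that already carry
        the neighbor lists from ``requested_neighbor_lists()``; the ``"energy"`` target
        name is illustrative)::

            from metatomic.torch import ModelOutput

            outputs = {
                "energy": ModelOutput(per_atom=False),  # per-structure predictions
                "mtt::aux::energy_last_layer_features": ModelOutput(per_atom=True),
            }
            predictions = model(systems, outputs)
            energies = predictions["energy"].block().values  # one row per structure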
+ """ + device = systems[0].device + return_dict: Dict[str, TensorMap] = {} + nl_options = self.requested_neighbor_lists()[0] + + if self.single_label.values.device != device: + self._move_labels_to_device(device) + + # **Stage 0: Input Preparation** + ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + padding_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + sample_labels, + centers, + nef_to_edges_neighbor, + ) = systems_to_batch( + systems, + nl_options, + self.atomic_types, + self.species_to_species_index, + self.cutoff_function, + self.cutoff_width, + self.num_neighbors_adaptive, + selected_atoms, + ) + + pair_sample_labels = get_pair_sample_labels( + systems, sample_labels, nl_options, device + ) + + # Optional diagnostic token capture: register temporary module hooks + diagnostic_handles = torch.jit.annotate(List[Any], []) + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + diagnostic_handles = self._prepare_diagnostic_handles( + outputs, + return_dict, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + + # the scaled_dot_product_attention function from torch cannot do + # double backward, so we will use manual attention if needed + use_manual_attention = edge_vectors.requires_grad and self.training + + with torch.profiler.record_function("PET::_calculate_features"): + # **Stage 1: Feature Computation via GNN Layers** + featurizer_inputs: Dict[str, torch.Tensor] = dict( + element_indices_nodes=element_indices_nodes, + element_indices_neighbors=element_indices_neighbors, + edge_vectors=edge_vectors, + edge_distances=edge_distances, + reverse_neighbor_index=reverse_neighbor_index, + padding_mask=padding_mask, + cutoff_factors=cutoff_factors, + ) + for featurizer_input_name, tensor in featurizer_inputs.items(): + if "mtt::features::" + featurizer_input_name in outputs: + return_dict["mtt::features::" + featurizer_input_name] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + node_features_list, edge_features_list = self._calculate_features( + featurizer_inputs, + use_manual_attention=use_manual_attention, + ) + + # If the long-range module is activated, we add the long-range features + # on top of the node features + + if self.long_range: + long_range_features = self._calculate_long_range_features( + systems, node_features_list, edge_distances, padding_mask + ) + for i in range(self.num_readout_layers): + node_features_list[i] = ( + node_features_list[i] + long_range_features + ) * 0.5**0.5 + + # **Stage 2: Intermediate Feature Output (Optional)** + with torch.profiler.record_function("PET::_get_output_features"): + if "features" in outputs: + features_dict = self._get_output_features( + node_features_list, + edge_features_list, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + # Since return_dict.update(features_dict) is not Torch-Scriptable, + # we use a simple iteration over the features_dict items. 
+ for k, v in features_dict.items(): + return_dict[k] = v + + # **Stage 3: Last Layer Feature Computation** + with torch.profiler.record_function("PET::_calculate_last_layer_features"): + node_last_layer_features_dict, edge_last_layer_features_dict = ( + self._calculate_last_layer_features( + node_features_list, + edge_features_list, + ) + ) + last_layer_features_dict = self._get_output_last_layer_features( + node_last_layer_features_dict, + edge_last_layer_features_dict, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + + for k, v in last_layer_features_dict.items(): + return_dict[k] = v + + # **Stage 4: Atomic Predictions** + with torch.profiler.record_function("PET::_calculate_atomic_predictions"): + node_atomic_predictions_dict, edge_atomic_predictions_dict = ( + self._calculate_atomic_predictions( + node_last_layer_features_dict, + edge_last_layer_features_dict, + padding_mask, + cutoff_factors, + outputs, + ) + ) + atomic_predictions_dict = self._get_output_atomic_predictions( + systems, + node_atomic_predictions_dict, + edge_atomic_predictions_dict, + edge_vectors, + system_indices, + sample_labels, + outputs, + selected_atoms, + ) + + for k, v in atomic_predictions_dict.items(): + return_dict[k] = v + + # **Post-processing (Evaluation Only)** + + with torch.profiler.record_function("PET::post-processing"): + if not self.training: + # at evaluation, we also introduce the scaler and additive contributions + return_dict = self.scaler( + systems, return_dict, selected_atoms=selected_atoms + ) + for additive_model in self.additive_models: + outputs_for_additive_model: Dict[str, ModelOutput] = {} + for name, output in outputs.items(): + if name in additive_model.outputs: + outputs_for_additive_model[name] = output + additive_contributions = additive_model( + systems, + outputs_for_additive_model, + selected_atoms, + ) + for name in additive_contributions: + # TODO: uncomment this after metatensor.torch.add + # is updated to handle sparse sums + # return_dict[name] = metatensor.torch.add( + # return_dict[name], + # additive_contributions[name].to( + # device=return_dict[name].device, + # dtype=return_dict[name].dtype + # ), + # ) + # TODO: "manual" sparse sum: update to metatensor.torch.add + # after sparse sum is implemented in metatensor.operations + output_blocks: List[TensorBlock] = [] + for k, b in return_dict[name].items(): + if k in additive_contributions[name].keys: + output_blocks.append( + _add_block_block( + b, + additive_contributions[name] + .block(k) + .to(device=b.device, dtype=b.dtype), + ) + ) + else: + output_blocks.append(b) + return_dict[name] = TensorMap( + return_dict[name].keys, output_blocks + ) + + # remove any diagnostic hooks we registered and attach tokens + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + for h in diagnostic_handles: + h.remove() + + return return_dict + + def _create_diagnostic_feature_tensormap( + self, + tensor: torch.Tensor, + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> TensorMap: + assert tensor.shape[0] == sample_labels.values.shape[0], ( + "diagnostic feature tensor must be per-atom or per-pair like in shape." 
+ f" Got tensor.shape = {tensor.shape}, " + ) + + device = tensor.device + + outp = tensor.detach().clone() + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + if outp.shape[1] == 1: # node-like, shape (n_atoms, 1, d) + outp = outp.squeeze(1) + labels = sample_labels + + else: # edge-like, shape (n_atoms, num_neighbors, d) + outp = outp[centers, nef_to_edges_neighbor] + labels = pair_sample_labels + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + return TensorMap( + Labels(["_"], torch.tensor([[0]])).to(device=device), + [ + TensorBlock( + values=outp, + samples=labels.to(device=device), + components=[], + properties=Labels( + ["_"], + torch.arange(outp.shape[1]).reshape(-1, 1), + ).to(device=device), + ) + ], + ) + + @torch.jit.ignore + def _prepare_diagnostic_handles( + self, + outputs: Dict[str, ModelOutput], + return_dict: Dict[str, Any], + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> List[Any]: + """ + Prepare forward hooks to capture diagnostic tokens from internal modules. + + :param outputs: Dictionary of requested outputs. + :param return_dict: Dictionary to store captured tokens. + :param centers: Tensor mapping center atoms to their indices. + :param nef_to_edges_neighbor: Tensor mapping neighbor edges to their indices. + :param sample_labels: Labels for individual atoms. + :param pair_sample_labels: Labels for atom pairs. + + :return: List of removable handles for the registered hooks. + """ + diagnostic_handles = [] + + def _resolve_module(path: str) -> Any: + obj: Any = self + for part in path.split("."): + if part.isdigit(): + obj = obj[int(part)] + else: + if not hasattr(obj, part): + raise AttributeError( + f"Module path '{path}' not found at '{part}'" + ) + obj = getattr(obj, part) + return obj + + # Build list of possible module paths that can be captured for these model + # hypers. 
+ possible_capture_paths: List[str] = [] + + for i in range(self.num_readout_layers): + possible_capture_paths.append(f"node_embedders.{i}") + possible_capture_paths.append("edge_embedder") + + for i in range(self.num_gnn_layers): + # Total GNN layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}_node") + possible_capture_paths.append(f"gnn_layers.{i}_edge") + + # Finer-grained: embeddings and compressions + possible_capture_paths.append(f"gnn_layers.{i}.edge_embedder") + if i > 0: + possible_capture_paths.append(f"gnn_layers.{i}.neighbor_embedder") + possible_capture_paths.append(f"gnn_layers.{i}.compress") + + # transformer layers + for j in range(self.num_attention_layers): + # Transformer layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_node") + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_edge") + + # Finer-grained: attention and MLP submodules + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_contraction" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_mlp" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_mlp" + ) + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}.mlp") + + # TODO: this is subtle, depending on whether the featurizer is feedforward or + # residual. To fix later. + # possible_capture_paths.append(f"combination_norms") + # possible_capture_paths.append(f"combination_mlps") + + for path in possible_capture_paths: + if "mtt::features::" + path not in outputs: + continue + + if "_node" in path: + suffix = "_node" + path = path.replace(suffix, "") + elif "_edge" in path: + suffix = "_edge" + path = path.replace(suffix, "") + else: + suffix = "" + + module = _resolve_module(path) + + def make_hook(p: str, suffix: str) -> Any: + def _hook( + module: torch.nn.Module, inp: Any, outp: Any + ) -> None: + if isinstance(outp, tuple): + assert "_node" in suffix or "_edge" in suffix, ( + "When capturing from a module that returns multiple " + "outputs, the requested output must carry the suffix " + "'_node' or '_edge'." + ) + if suffix == "_node": + tensor = outp[0] + else: + assert suffix == "_edge" + tensor = outp[1] + + return_dict[f"mtt::features::{p}{suffix}"] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + else: + return_dict["mtt::features::" + p] = ( + self._create_diagnostic_feature_tensormap( + outp, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + + return _hook + + handle = module.register_forward_hook(make_hook(path, suffix)) + diagnostic_handles.append(handle) + + return diagnostic_handles + + def _calculate_features( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Calculate node and edge features using the selected featurization strategy. + Returns lists of feature tensors from GNN layers. 
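        Both strategies reorder edge messages with the reversed neighbor list; a
        minimal standalone sketch of that reindexing (shapes and the permutation are
        illustrative)::

            import torch

            n_atoms, max_neigh, d = 4, 3, 8
            messages = torch.randn(n_atoms, max_neigh, d)
            # for each flattened ij edge, the position of the corresponding ji edge
            reverse_neighbor_index = torch.randperm(n_atoms * max_neigh)
            reversed_messages = (
                messages.reshape(n_atoms * max_neigh, d)[reverse_neighbor_index]
                .reshape(n_atoms, max_neigh, d)
            )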
+ + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors + - List of edge feature tensors + In the case of feedforward featurization, each list contains a single tensor + from the final GNN layer. In the case of residual featurization, each list + contains tensors from all GNN layers. + """ + if self.featurizer_type == "feedforward": + return self._feedforward_featurization_impl(inputs, use_manual_attention) + else: + return self._residual_featurization_impl(inputs, use_manual_attention) + + def _feedforward_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Feedforward featurization: iterates features through all GNN layers, + returning only the final layer outputs. Uses combination MLPs to mix + forward and reversed edge messages at each layer. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from the final GNN layer + - List of edge feature tensors from the final GNN layer + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + + input_node_embeddings = self.node_embedders[0](inputs["element_indices_nodes"]) + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for combination_norm, combination_mlp, gnn_layer in zip( + self.combination_norms, self.combination_mlps, self.gnn_layers, strict=True + ): + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + input_node_embeddings = output_node_embeddings + new_input_edge_embeddings = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + # input_messages = 0.5 * (output_edge_embeddings + new_input_messages) + concatenated = torch.cat( + [output_edge_embeddings, new_input_edge_embeddings], dim=-1 + ) + input_edge_embeddings = ( + input_edge_embeddings + + output_edge_embeddings + + combination_mlp(combination_norm(concatenated)) + ) + + node_features_list.append(input_node_embeddings) + edge_features_list.append(input_edge_embeddings) + return node_features_list, edge_features_list + + def _residual_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Residual featurization: saves intermediate features from each GNN layer + for use in readout. 
Averages forward and reversed edge messages between layers. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from all GNN layers + - List of edge feature tensors from all GNN layers + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for node_embedder, gnn_layer in zip( + self.node_embedders, self.gnn_layers, strict=True + ): + input_node_embeddings = node_embedder(inputs["element_indices_nodes"]) + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + node_features_list.append(output_node_embeddings) + edge_features_list.append(output_edge_embeddings) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + # (Flatten, index, and reshape to the original shape) + new_input_messages = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + input_edge_embeddings = 0.5 * (input_edge_embeddings + new_input_messages) + return node_features_list, edge_features_list + + def _calculate_long_range_features( + self, + systems: List[System], + node_features_list: List[torch.Tensor], + edge_distances: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + """ + Calculate long-range electrostatic features using Ewald summation. + Forces use_ewald=True during training for stability. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_distances: Tensor of edge distances [n_atoms, max_num_neighbors]. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :return: Tensor of long-range features [n_atoms, d_pet]. + """ + if self.training: + # Currently, the long-range implementation show instabilities + # during training if P3MCalculator is used instead of the + # EwaldCalculator. We will use the EwaldCalculator for training. 
+ self.long_range_featurizer.use_ewald = True + flattened_lengths = edge_distances[padding_mask] + short_range_features = ( + torch.stack(node_features_list).sum(dim=0) + * (1 / len(node_features_list)) ** 0.5 + ) + long_range_features = self.long_range_featurizer( + systems, short_range_features, flattened_lengths + ) + return long_range_features + + def _get_output_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Concatenate node and edge features from all layers into intermediate + feature representations. Edge features are summed with cutoff weighting. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping "features" to a TensorMap of intermediate + representations, either per-atom or summed over atoms. + """ + features_dict: Dict[str, TensorMap] = {} + node_features = torch.cat(node_features_list, dim=1) + edge_features = torch.cat(edge_features_list, dim=2) + edge_features = (edge_features * cutoff_factors[:, :, None]).sum(dim=1) + features = torch.cat([node_features, edge_features], dim=1) + + feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=features, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + features.shape[-1], device=features.device + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + feature_tmap = mts.slice( + feature_tmap, + axis="samples", + selection=selected_atoms, + ) + if requested_outputs["features"].per_atom: + features_dict["features"] = feature_tmap + else: + features_dict["features"] = sum_over_atoms(feature_tmap) + return features_dict + + def _calculate_last_layer_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + ) -> Tuple[Dict[str, List[torch.Tensor]], Dict[str, List[torch.Tensor]]]: + """ + Apply output-specific heads to node and edge features from each GNN layer. + Returns dictionaries mapping output names to lists of head-transformed features. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of node last layer features + - Dictionary mapping output names to lists of edge last layer features + """ + node_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + + # Calculating node last layer features + for output_name, node_heads in self.node_heads.items(): + if output_name not in node_last_layer_features_dict: + node_last_layer_features_dict[output_name] = [] + for i, node_head in enumerate(node_heads): + node_last_layer_features_dict[output_name].append( + node_head(node_features_list[i]) + ) + + # Calculating edge last layer features + for output_name, edge_heads in self.edge_heads.items(): + if output_name not in edge_last_layer_features_dict: + edge_last_layer_features_dict[output_name] = [] + for i, edge_head in enumerate(edge_heads): + edge_last_layer_features_dict[output_name].append( + edge_head(edge_features_list[i]) + ) + + return node_last_layer_features_dict, edge_last_layer_features_dict + + def _get_output_last_layer_features( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge last layer features for requested last layer + features output. Edge features are summed with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping requested last layer features output names + to TensorMaps of last layer features, either per-atom or summed over atoms. 
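            The auxiliary output names follow the convention implemented by
            ``get_last_layer_features_name`` at the bottom of this file, e.g.
            (illustrative target names)::

                assert get_last_layer_features_name("energy") == "mtt::aux::energy_last_layer_features"
                assert get_last_layer_features_name("mtt::dos") == "mtt::aux::dos_last_layer_features"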
+ """ + last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + last_layer_features_outputs: Dict[str, TensorMap] = {} + for output_name in node_last_layer_features_dict.keys(): + if not should_compute_last_layer_features(output_name, requested_outputs): + continue + if output_name not in last_layer_features_dict: + last_layer_features_dict[output_name] = [] + for i in range(len(node_last_layer_features_dict[output_name])): + node_last_layer_features = node_last_layer_features_dict[output_name][i] + edge_last_layer_features = edge_last_layer_features_dict[output_name][i] + edge_last_layer_features = ( + edge_last_layer_features * cutoff_factors[:, :, None] + ).sum(dim=1) + last_layer_features_dict[output_name].append(node_last_layer_features) + last_layer_features_dict[output_name].append(edge_last_layer_features) + + for output_name in requested_outputs: + if not ( + output_name.startswith("mtt::aux::") + and output_name.endswith("_last_layer_features") + ): + continue + base_name = output_name.replace("mtt::aux::", "").replace( + "_last_layer_features", "" + ) + # the corresponding output could be base_name or mtt::base_name + if f"mtt::{base_name}" in last_layer_features_dict: + base_name = f"mtt::{base_name}" + last_layer_features_values = torch.cat( + last_layer_features_dict[base_name], dim=1 + ) + last_layer_feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=last_layer_features_values, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + last_layer_features_values.shape[-1], + device=last_layer_features_values.device, + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + last_layer_feature_tmap = mts.slice( + last_layer_feature_tmap, + axis="samples", + selection=selected_atoms, + ) + last_layer_features_options = requested_outputs[output_name] + if last_layer_features_options.per_atom: + last_layer_features_outputs[output_name] = last_layer_feature_tmap + else: + last_layer_features_outputs[output_name] = sum_over_atoms( + last_layer_feature_tmap + ) + return last_layer_features_outputs + + def _calculate_atomic_predictions( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + padding_mask: torch.Tensor, + cutoff_factors: torch.Tensor, + outputs: Dict[str, ModelOutput], + ) -> Tuple[ + Dict[str, List[List[torch.Tensor]]], Dict[str, List[List[torch.Tensor]]] + ]: + """ + Apply final linear layers to last layer features to produce + per-atom predictions. Handles multiple blocks per output and sums + edge contributions with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param outputs: Dictionary of requested outputs. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of lists of node atomic + prediction tensors (one list per GNN layer, one tensor per block) + - Dictionary mapping output names to lists of lists of edge atomic + prediction tensors (one list per GNN layer, one tensor per block) + """ + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + + # Computing node atomic predictions. Since we have last layer features + # for each GNN layer, and each last layer can have multiple blocks, + # we apply each last layer block to each of the last layer features. + + for output_name, node_last_layers in self.node_last_layers.items(): + if output_name in outputs: + node_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, node_last_layer in enumerate(node_last_layers): + node_last_layer_features = node_last_layer_features_dict[ + output_name + ][i] + node_atomic_predictions_by_block: List[torch.Tensor] = [] + for node_last_layer_by_block in node_last_layer.values(): + node_atomic_predictions_by_block.append( + node_last_layer_by_block(node_last_layer_features) + ) + node_atomic_predictions_dict[output_name].append( + node_atomic_predictions_by_block + ) + + # Computing edge atomic predictions. Following the same logic as above, + # we (1) iterate over the last layer features and last layer blocks, and (2) + # sum the edge features with cutoff factors to get their per-node contribution. + + for output_name, edge_last_layers in self.edge_last_layers.items(): + if output_name in outputs: + edge_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, edge_last_layer in enumerate(edge_last_layers): + edge_last_layer_features = edge_last_layer_features_dict[ + output_name + ][i] + edge_atomic_predictions_by_block: List[torch.Tensor] = [] + for edge_last_layer_by_block in edge_last_layer.values(): + edge_atomic_predictions = edge_last_layer_by_block( + edge_last_layer_features + ) + expanded_padding_mask = padding_mask[..., None].repeat( + 1, 1, edge_atomic_predictions.shape[2] + ) + edge_atomic_predictions = torch.where( + ~expanded_padding_mask, 0.0, edge_atomic_predictions + ) + edge_atomic_predictions_by_block.append( + (edge_atomic_predictions * cutoff_factors[:, :, None]).sum( + dim=1 + ) + ) + edge_atomic_predictions_dict[output_name].append( + edge_atomic_predictions_by_block + ) + + return node_atomic_predictions_dict, edge_atomic_predictions_dict + + def _get_output_atomic_predictions( + self, + systems: List[System], + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_vectors: torch.Tensor, + system_indices: torch.Tensor, + sample_labels: Labels, + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge atomic predictions into final TensorMaps. + Handles rank-2 Cartesian tensors by symmetrizing them. + Returns per-atom or per-structure predictions based on output configuration. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of node atomic prediction tensors (one list per GNN layer, + one tensor per block). 
+ :param edge_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of edge atomic prediction tensors (one list per GNN layer, + one tensor per block). + :param edge_vectors: Tensor of edge vectors [n_atoms, max_num_neighbors, 3]. + :param system_indices: Tensor mapping each atom to its system index + [n_atoms]. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param outputs: Dictionary of requested outputs. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :return: Dictionary mapping requested output names to TensorMaps of + predictions, either per-atom or summed over atoms. + """ + atomic_predictions_tmap_dict: Dict[str, TensorMap] = {} + for output_name in self.target_names: + if output_name in outputs: + atomic_predictions_by_block = { + key: torch.zeros( + 1, dtype=edge_vectors.dtype, device=edge_vectors.device + ) + for key in self.output_shapes[output_name].keys() + } + + node_atomic_predictions_by_block = node_atomic_predictions_dict[ + output_name + ] + edge_atomic_predictions_by_block = edge_atomic_predictions_dict[ + output_name + ] + for i in range(len(node_atomic_predictions_by_block)): + node_atomic_prediction_block = node_atomic_predictions_by_block[i] + edge_atomic_prediction_block = edge_atomic_predictions_by_block[i] + for j, key in enumerate(atomic_predictions_by_block): + node_atomic_predictions = node_atomic_prediction_block[j] + edge_atomic_predictions = edge_atomic_prediction_block[j] + atomic_predictions_by_block[key] = atomic_predictions_by_block[ + key + ] + (node_atomic_predictions + edge_atomic_predictions) + + if output_name == "non_conservative_stress": # TODO: variants + block_key = list(atomic_predictions_by_block.keys())[0] + output_shapes_values = list( + self.output_shapes[output_name].values() + ) + num_properties = output_shapes_values[0][-1] + symmetrized = process_non_conservative_stress( + atomic_predictions_by_block[block_key], + systems, + system_indices, + num_properties, + ) + atomic_predictions_by_block[block_key] = symmetrized + + blocks = [ + TensorBlock( + values=atomic_predictions_by_block[key].reshape([-1] + shape), + samples=sample_labels, + components=components, + properties=properties, + ) + for key, shape, components, properties in zip( + self.output_shapes[output_name].keys(), + self.output_shapes[output_name].values(), + self.component_labels[output_name], + self.property_labels[output_name], + strict=True, + ) + ] + atomic_predictions_tmap_dict[output_name] = TensorMap( + keys=self.key_labels[output_name], + blocks=blocks, + ) + # If selected atoms request is provided, we slice the atomic predictions + # tensor maps to get the predictions for the selected atoms only. + + if selected_atoms is not None: + for output_name, tmap in atomic_predictions_tmap_dict.items(): + atomic_predictions_tmap_dict[output_name] = mts.slice( + tmap, axis="samples", selection=selected_atoms + ) + + # If per-atom predictions are requested, we return the atomic predictions + # tensor maps. Otherwise, we sum the atomic predictions over the atoms + # to get the final per-structure predictions for each requested output. 
+ + for output_name, atomic_property in atomic_predictions_tmap_dict.items(): + if outputs[output_name].per_atom: + atomic_predictions_tmap_dict[output_name] = atomic_property + else: + atomic_predictions_tmap_dict[output_name] = sum_over_atoms( + atomic_property + ) + + return atomic_predictions_tmap_dict + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + context: Literal["restart", "finetune", "export"], + ) -> "PET": + if context == "restart": + logging.info(f"Using latest model from epoch {checkpoint['epoch']}") + model_state_dict = checkpoint["model_state_dict"] + elif context in {"finetune", "export"}: + logging.info(f"Using best model from epoch {checkpoint['best_epoch']}") + model_state_dict = checkpoint["best_model_state_dict"] + else: + raise ValueError("Unknown context tag for checkpoint loading!") + + # Create the model + model_data = checkpoint["model_data"] + model = cls( + hypers=model_data["model_hypers"], + dataset_info=model_data["dataset_info"], + ) + + finetune_config = model_state_dict.pop("finetune_config", {}) + if finetune_config: + # Apply the finetuning strategy + model = apply_finetuning_strategy(model, finetune_config) + state_dict_iter = iter(model_state_dict.values()) + next(state_dict_iter) # skip the species_to_species_index + dtype = next(state_dict_iter).dtype + model.to(dtype).load_state_dict(model_state_dict) + model.additive_models[0].sync_tensor_maps() + model.scaler.sync_tensor_maps() + + # Loading the metadata from the checkpoint + model.metadata = merge_metadata(model.metadata, checkpoint.get("metadata")) + + return model + + def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: + dtype = next(self.parameters()).dtype + if dtype not in self.__supported_dtypes__: + raise ValueError(f"unsupported dtype {dtype} for PET") + + # Make sure the model is all in the same dtype + # For example, after training, the additive models could still be in + # float64 + self.to(dtype) + + # Additionally, the composition model contains some `TensorMap`s that cannot + # be registered correctly with Pytorch. This function moves them: + self.additive_models[0].weights_to(torch.device("cpu"), torch.float64) + + interaction_ranges = [self.num_gnn_layers * self.cutoff] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + + capabilities = ModelCapabilities( + outputs=self.outputs, + atomic_types=self.atomic_types, + interaction_range=interaction_range, + length_unit=self.dataset_info.length_unit, + supported_devices=self.__supported_devices__, + dtype=dtype_to_str(dtype), + ) + + metadata = merge_metadata(self.metadata, metadata) + + return AtomisticModel(self.eval(), metadata, capabilities) + + def _add_output(self, target_name: str, target_info: TargetInfo) -> None: + """ + Register a new output target by creating corresponding heads and last layers. + Sets up node/edge heads and linear layers for all readout layers. + + :param target_name: Name of the target to add. + :param target_info: TargetInfo object containing details about the target. + """ + # one output shape for each tensor block, grouped by target (i.e. 
tensormap) + self.output_shapes[target_name] = {} + for key, block in target_info.layout.items(): + dict_key = target_name + for n, k in zip(key.names, key.values, strict=True): + dict_key += f"_{n}_{int(k)}" + self.output_shapes[target_name][dict_key] = [ + len(comp.values) for comp in block.components + ] + [len(block.properties.values)] + + self.outputs[target_name] = ModelOutput( + quantity=target_info.quantity, + unit=target_info.unit, + per_atom=True, + description=target_info.description, + ) + + self.node_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_node, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_pet, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.node_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: torch.nn.Linear( + self.d_head, + prod(shape), + bias=True, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: torch.nn.Linear( + self.d_head, + prod(shape), + bias=True, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + # Register last-layer parameters, in the same order as they are returned as + # last-layer features in the model + self.last_layer_parameter_names[target_name] = [] + for layer_index in range(self.num_readout_layers): + for key in self.output_shapes[target_name].keys(): + self.last_layer_parameter_names[target_name].append( + f"node_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + self.last_layer_parameter_names[target_name].append( + f"edge_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + + ll_features_name = get_last_layer_features_name(target_name) + self.outputs[ll_features_name] = ModelOutput( + per_atom=True, description=f"last layer features for {target_name}" + ) + self.key_labels[target_name] = target_info.layout.keys + self.component_labels[target_name] = [ + block.components for block in target_info.layout.blocks() + ] + self.property_labels[target_name] = [ + block.properties for block in target_info.layout.blocks() + ] + + def _move_labels_to_device(self, device: torch.device) -> None: + self.single_label = self.single_label.to(device) + self.key_labels = { + output_name: label.to(device) + for output_name, label in self.key_labels.items() + } + self.component_labels = { + output_name: [ + [labels.to(device) for labels in components_block] + for components_block in components_tmap + ] + for output_name, components_tmap in self.component_labels.items() + } + self.property_labels = { + output_name: [labels.to(device) for labels in properties_tmap] + for output_name, properties_tmap in self.property_labels.items() + } + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["model_ckpt_version"] == v: + update = getattr(checkpoints, f"model_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["model_ckpt_version"] = v + 1 + + if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__: + raise 
RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using model " + f"version {checkpoint['model_ckpt_version']}, while the current model " + f"version is {cls.__checkpoint_version__}." + ) + + return checkpoint + + def get_checkpoint(self) -> Dict: + model_state_dict = self.state_dict() + model_state_dict["finetune_config"] = self.finetune_config + checkpoint = { + "architecture_name": "pet_12", + "model_ckpt_version": self.__checkpoint_version__, + "metadata": self.metadata, + "model_data": { + "model_hypers": self.hypers, + "dataset_info": self.dataset_info, + }, + "epoch": None, + "best_epoch": None, + "model_state_dict": model_state_dict, + "best_model_state_dict": self.state_dict(), + } + return checkpoint + + +def process_non_conservative_stress( + tensor: torch.Tensor, + systems: List[System], + system_indices: torch.Tensor, + num_properties: int, +) -> torch.Tensor: + """ + Symmetrizes and normalizes by the volume rank-2 Cartesian tensors that are meant + to predict the non-conservative stress. + + :param tensor: Tensor of shape [n_atoms, 9 * num_properties]. + :param systems: List of `metatomic.torch.System` objects to process. + :param system_indices: Tensor mapping each atom to its system index [n_atoms]. + :param num_properties: Number of properties in the tensor (e.g., 6 for stress). + :return: Symmetrized tensor of shape [n_atoms, 3, 3, num_properties], divided by the + cell volume. + """ + # Reshape to 3x3 matrix per atom + tensor_as_three_by_three = tensor.reshape(-1, 3, 3, num_properties) + + # Normalize by cell volume + volumes = torch.stack([torch.abs(torch.det(system.cell)) for system in systems]) + # Zero volume can happen due to metatomic's convention of zero cell + # vectors for non-periodic directions. The actual volume is +inf + volumes[volumes == 0.0] = torch.inf + volumes_by_atom = volumes[system_indices].unsqueeze(1).unsqueeze(2).unsqueeze(3) + tensor_as_three_by_three = tensor_as_three_by_three / volumes_by_atom + + # Symmetrize + tensor_as_three_by_three = ( + tensor_as_three_by_three + tensor_as_three_by_three.transpose(1, 2) + ) / 2.0 + + return tensor_as_three_by_three + + +def get_last_layer_features_name(target_name: str) -> str: + """ + Get the auxiliary output name for last layer features of a target. + + :param target_name: Name of the target. + :return: Name of the corresponding last layer features output. + """ + base_name = target_name.replace("mtt::", "") + return f"mtt::aux::{base_name}_last_layer_features" + + +def should_compute_last_layer_features( + output_name: str, requested_outputs: Dict[str, ModelOutput] +) -> bool: + """ + Check if last layer features should be computed for an output. + + :param output_name: Name of the output to check. + :param requested_outputs: Dictionary of requested outputs. + :return: True if last layer features should be computed, False otherwise. + """ + if output_name in requested_outputs: + return True + ll_features_name = get_last_layer_features_name( + output_name.replace("mtt::aux::", "") + ) + return ll_features_name in requested_outputs diff --git a/src/metatrain/pet_12/modules/adaptive_cutoff.py b/src/metatrain/pet_12/modules/adaptive_cutoff.py new file mode 100644 index 0000000000..5f6206bdb1 --- /dev/null +++ b/src/metatrain/pet_12/modules/adaptive_cutoff.py @@ -0,0 +1,166 @@ +from typing import Optional + +import torch + +from .utilities import cutoff_func_bump as cutoff_func + + +# minimum value for the probe cutoff. this avoids getting too close +# to the central atom. 
in practice it could be also set to a larger value +DEFAULT_MIN_PROBE_CUTOFF = 0.5 +# recommended smooth cutoff width for effective neighbor number calculation +# smaller values lead to a more "step-like" behavior, but can be +# numerically unstable. in practice this will be called with the +# same cutoff as the main cutoff function +DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH = 1.0 + + +def get_adaptive_cutoffs( + centers: torch.Tensor, + edge_distances: torch.Tensor, + num_neighbors_adaptive: float, + num_nodes: int, + max_cutoff: float, + min_cutoff: float = DEFAULT_MIN_PROBE_CUTOFF, + cutoff_width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, + probe_spacing: Optional[float] = None, + weight_width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the adaptive cutoff values for each center atom. + + :param centers: Indices of the center atoms. + :param edge_distances: Distances between centers and their neighbors. + :param num_neighbors_adaptive: Target number of neighbors per atom. + :param num_nodes: Total number of center atoms. + :param max_cutoff: Maximum cutoff distance to consider. + :param min_cutoff: Minimum cutoff distance to consider. + :param cutoff_width: Width of the smooth cutoff taper region. + :param probe_spacing: Spacing between probe cutoffs. If None, it will be + automatically determined from the cutoff width. + :param weight_width: Width of the cutoff selection weight function. If None, it + will be automatically determined from the empirical neighbor counts. + :return: Adapted cutoff distances for each center atom. + """ + + # heuristic for the grid spacing of probe cutoffs, based on a + # the smoothness of the cutoff function + if probe_spacing is None: + probe_spacing = cutoff_width / 4.0 + probe_cutoffs = torch.arange( + min_cutoff, + max_cutoff, + probe_spacing, + device=edge_distances.device, + dtype=edge_distances.dtype, + ) + with torch.profiler.record_function("PET::get_effective_num_neighbors"): + effective_num_neighbors = get_effective_num_neighbors( + edge_distances, + probe_cutoffs, + centers, + num_nodes, + width=cutoff_width, + ) + + with torch.profiler.record_function("PET::get_cutoff_weights"): + cutoffs_weights = get_gaussian_cutoff_weights( + effective_num_neighbors, + num_neighbors_adaptive, + width=weight_width, + ) + with torch.profiler.record_function("PET::calculate_adapted_cutoffs"): + adapted_atomic_cutoffs = probe_cutoffs @ cutoffs_weights.T + return adapted_atomic_cutoffs + + +def get_effective_num_neighbors( + edge_distances: torch.Tensor, + probe_cutoffs: torch.Tensor, + centers: torch.Tensor, + num_nodes: int, + width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, +) -> torch.Tensor: + """ + Computes the effective number of neighbors for each probe cutoff. + + :param edge_distances: Distances between centers and their neighbors. + :param probe_cutoffs: Probe cutoff distances. + :param centers: Indices of the center atoms. + :param num_nodes: Total number of center atoms. + :param width: Width of the cutoff function. + :return: Effective number of neighbors for each center atom and probe cutoff. 
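+
+    A rough usage sketch (illustrative only; the smooth bump cutoff makes the
+    counts fractional near each probe cutoff)::
+
+        distances = torch.tensor([0.9, 1.1, 2.0, 2.4])  # all edges of atom 0
+        probes = torch.tensor([1.0, 2.0, 3.0])
+        centers = torch.zeros(4, dtype=torch.long)
+        counts = get_effective_num_neighbors(distances, probes, centers, num_nodes=1)
+        # counts[0] grows smoothly with the probe cutoff and approaches the
+        # plain neighbor count (here 4) once the probe cutoff exceeds all
+        # edge distances by more than the cutoff width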
+ """ + + weights = cutoff_func( + edge_distances.unsqueeze(0), probe_cutoffs.unsqueeze(1), width + ) + + probe_num_neighbors = torch.zeros( + (len(probe_cutoffs), num_nodes), + dtype=edge_distances.dtype, + device=edge_distances.device, + ) + # accumulate the weights for all probe cutoffs and center atoms at once + probe_num_neighbors.index_add_(1, centers, weights) + probe_num_neighbors = probe_num_neighbors.T + + return probe_num_neighbors + + +def get_gaussian_cutoff_weights( + effective_num_neighbors: torch.Tensor, + num_neighbors_adaptive: float, + width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the weights for each probe cutoff based on + the effective number of neighbors using Gaussian weights + centered at the expected number of neighbors. + + :param effective_num_neighbors: Effective number of neighbors for each center atom + and probe cutoff. + :param num_neighbors_adaptive: Target maximum number of neighbors per atom. + :param width: Width of the Gaussian cutoff selection function. + :return: Weights for each probe cutoff. + """ + diff = effective_num_neighbors - num_neighbors_adaptive + + # adds a "baseline" corresponding to uniformly-distributed atoms + # this has multiple "good" effects: it pushes the cutoff "out" when + # there are few neighbors, and "in" when there are many, and it + # stabilizes the weights with respect to variations in the neighbor + # distribution when there are empty ranges leading to "flat" + # neighbor count distribution + x = torch.linspace( + 0, + 1, + effective_num_neighbors.shape[1], + device=effective_num_neighbors.device, + dtype=effective_num_neighbors.dtype, + ) + baseline = num_neighbors_adaptive * x**3 + + diff = diff + baseline.unsqueeze(0) + if width is None: + # adaptive width from neighbor-count slope along probe axis (last dim) + eps = 1e-12 + if diff.shape[-1] == 1: + # Can't compute gradient from single point; use scaled diff as proxy + width_t = diff.abs() * 0.5 + eps + else: + # Compute numerical gradient: centered differences for interior, + # one-sided differences at boundaries + (width_t,) = torch.gradient(diff, dim=-1) + width_t = width_t.abs().clamp_min(eps) + else: + width_t = torch.ones_like(diff) * width + + logw = -0.5 * (diff / width_t) ** 2 + weights = torch.exp(logw - logw.max()) + + # row-wise normalization of the weights + weights_sum = weights.sum(dim=1, keepdim=True) + weights = weights / weights_sum + + return weights diff --git a/src/metatrain/pet_12/modules/finetuning.py b/src/metatrain/pet_12/modules/finetuning.py new file mode 100644 index 0000000000..27f2086f19 --- /dev/null +++ b/src/metatrain/pet_12/modules/finetuning.py @@ -0,0 +1,264 @@ +# mypy: disable-error-code=misc +# We ignore misc errors in this file because TypedDict +# with default values is not allowed by mypy. 
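+#
+# A minimal, illustrative example of a strategy dict accepted by
+# ``apply_finetuning_strategy`` below (the checkpoint path is a placeholder
+# and the values are not tuned):
+#
+#     strategy: LoRaFinetuneHypers = {
+#         "method": "lora",
+#         "read_from": "path/to/pretrained.ckpt",
+#         "config": {"rank": 4, "alpha": 8.0},
+#         "inherit_heads": {},
+#     }
+#     model = apply_finetuning_strategy(model, strategy)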
+from typing import Any, Optional, Tuple + +import torch +import torch.nn as nn +from typing_extensions import Literal, NotRequired, TypedDict + + +class LoRaFinetuneConfig(TypedDict): + """Configuration for LoRA finetuning strategy.""" + + rank: int + """Rank of the LoRA matrices.""" + alpha: float + """Scaling factor for the LoRA matrices.""" + target_modules: NotRequired[list[str]] + + +class HeadsFinetuneConfig(TypedDict): + """Configuration for heads finetuning strategy.""" + + head_modules: list[str] + """List of module name prefixes for the prediction heads to finetune.""" + last_layer_modules: list[str] + """List of module name prefixes for the last layers to finetune.""" + + +class NoFinetuneHypers(TypedDict): + """Hypers that indicate that no finetuning is to be applied.""" + + read_from: None = None + """No finetuning is indicated by setting this argument to None. + + The rest of finetuning hyperparameters are then ignored. + """ + method: NotRequired[Any] + config: NotRequired[Any] + inherit_heads: NotRequired[Any] + + +class FullFinetuneHypers(TypedDict): + """Hyperparameters to use full finetuning of PET models. + + This means all model parameters are trainable. + """ + + method: Literal["full"] = "full" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: NotRequired[Any] + """No configuration needed for full finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class LoRaFinetuneHypers(TypedDict): + """Hyperparameters for LoRA finetuning of PET models. + + Injects LoRA layers and finetunes only them. + """ + + method: Literal["lora"] = "lora" + """Finetuning method to use""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: LoRaFinetuneConfig + """Configuration for LoRA finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class HeadsFinetuneHypers(TypedDict): + """Hyperparameters for heads finetuning of PET models. + + Freezes all model parameters except for the prediction heads + and last layers. + """ + + method: Literal["heads"] = "heads" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: HeadsFinetuneConfig + """Configuration for heads finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +FinetuneHypers = FullFinetuneHypers | LoRaFinetuneHypers | HeadsFinetuneHypers + + +def apply_finetuning_strategy(model: nn.Module, strategy: FinetuneHypers) -> nn.Module: + """ + Apply the specified finetuning strategy to the model. + This function modifies the model in place based on the provided strategy. + + :param model: The model to be finetuned. + :param strategy: A dictionary specifying the finetuning strategy. + The strategy method can be one of the following: + - lora: Inject LoRA layers into the model, or reapply training if already + present. 
+ - heads: Freeze all parameters except for the heads and last layers. + - full: All parameters are trainable. + Additionally, the strategy can include an "inherit_heads" key, + which is a dictionary mapping the new trainable targets to the existing + targets in the model. This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization. + :return: The modified model with the finetuning strategy applied. + """ + + for param in model.parameters(): + param.requires_grad = True + + if strategy["method"] == "full": + # Full finetuning, all parameters are trainable + pass + + elif strategy["method"] == "lora": + lora_config = strategy["config"] + lora_already_applied = any(isinstance(m, LoRALinear) for m in model.modules()) + if not lora_already_applied: + model_device = next(model.parameters()).device + model_dtype = next(model.parameters()).dtype + model = inject_lora_layers( + model, + target_modules=tuple( + lora_config.get("target_modules", ["input_linear", "output_linear"]) + ), + rank=lora_config.get("rank", 4), + alpha=lora_config.get("alpha", 8), + device=model_device, + dtype=model_dtype, + ) + + # Freeze all except LoRA + for name, param in model.named_parameters(): + if "lora_" not in name: + param.requires_grad = False + + elif strategy["method"] == "heads": + heads_config = strategy.get( + "config", + { + "head_modules": ["node_heads", "edge_heads"], + "last_layer_modules": ["node_last_layers", "edge_last_layers"], + }, + ) + + head_keywords = heads_config.get("head_modules", []) + last_layer_keywords = heads_config.get("last_layer_modules", []) + + for name, param in model.named_parameters(): + if any(name.startswith(kw) for kw in head_keywords + last_layer_keywords): + param.requires_grad = True + else: + param.requires_grad = False + + else: + raise ValueError( + f"Unknown finetuning strategy: {strategy['method']}. Available methods " + "are: 'full', 'lora', 'heads'." + ) + + model.finetune_config = strategy + + inherit_heads_config = strategy["inherit_heads"] + if inherit_heads_config: + for dest_target_name, source_target_name in inherit_heads_config.items(): + model_parameters = dict(model.named_parameters()) + if not any(f".{source_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the source target name '{source_target_name}' was not found in " + "the model. Please specify the correct source target name." + ) + if not any(f".{dest_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the destination target name '{dest_target_name}' was not found " + "in the model. Please specify the correct destination target name." + ) + for name, param in model_parameters.items(): + if f".{source_target_name}." in name: + corresponding_dest_name = name.replace( + source_target_name, dest_target_name + ) + if corresponding_dest_name in model_parameters: + model_parameters[corresponding_dest_name].data.copy_(param.data) + else: + raise ValueError( + f"Destination head '{dest_target_name}' not found in model." + ) + return model + + +def inject_lora_layers( + model: nn.Module, + target_modules: Tuple[str, ...] = ("input_linear", "output_linear"), + rank: int = 4, + alpha: float = 1.0, + device: Optional[str] = None, + dtype: Optional[torch.dtype] = None, +) -> nn.Module: + """ + Inject LoRA layers into the model. 
+ This function replaces the specified linear layers in the model with + LoRALinear layers. + + :param model: The model to modify. + :param target_modules: A tuple of strings specifying the names of the attributes of + the modules to be replaced with LoRA layers. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + :param device: The device to which the LoRA layers should be moved. If None, the + LoRA layers will be on the same device as the original model. + :param dtype: The data type to which the LoRA layers should be converted. If + None, the LoRA layers will have the same dtype as the original model. + :return: The modified model with LoRA layers injected. + """ + for _, module in model.named_modules(): + for attr in target_modules: + if hasattr(module, attr): + linear = getattr(module, attr) + if isinstance(linear, nn.Linear): + lora_linear = LoRALinear(linear, rank=rank, alpha=alpha) + lora_linear = lora_linear.to(dtype=dtype, device=device) + setattr(module, attr, lora_linear) + return model + + +class LoRALinear(nn.Module): + """ + LoRA Linear layer. + This is a wrapper around nn.Linear that adds LoRA functionality. + LoRA is a technique for low-rank adaptation of large language models. + It allows for efficient fine-tuning of large models by injecting low-rank + matrices into the model's weights. + + :param linear_layer: The original linear layer to be wrapped. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + """ + + def __init__(self, linear_layer: nn.Module, rank: int = 4, alpha: float = 1.0): + super().__init__() + self.linear = linear_layer + self.lora_A = nn.Linear(linear_layer.in_features, rank, bias=False) + self.lora_B = nn.Linear(rank, linear_layer.out_features, bias=False) + self.scaling = alpha / rank + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + self.scaling * self.lora_B(self.lora_A(x)) diff --git a/src/metatrain/pet_12/modules/nef.py b/src/metatrain/pet_12/modules/nef.py new file mode 100644 index 0000000000..20ed6dc8fc --- /dev/null +++ b/src/metatrain/pet_12/modules/nef.py @@ -0,0 +1,248 @@ +""" +Module with functions to manipulate NEF (Node Edge Feature) arrays. + +The NEF representation is what the internals of PET use. +In the NEF representation, the first dimension is the center node +(i.e. the "i" node in an "i -> j" edge), and the second dimension +is the edges for that node. Not all center nodes have the same number +of edges, so padding is used to ensure that all nodes have the same +number of edges. + +Most of the functions have the purpose of converting between +edge arrays with shape (n_edges, ...) and NEF arrays with shape +(n_nodes, n_edges_per_node, ...). +""" + +from typing import List, Optional, Tuple + +import torch + + +def get_nef_indices( + centers: torch.Tensor, n_nodes: int, n_edges_per_node: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes tensors of indices useful to convert between edge + and NEF layouts; the usage and function of `nef_indices` and + `nef_to_edges_neighbor` is clear in the ``edge_array_to_nef`` + and ``nef_array_to_edges`` functions below. + + :param centers: A 1D tensor of shape (n_edges,) containing the + indices of the center nodes for each edge, with the center nodes + being the "i" node in an "i -> j" edge. + :param n_nodes: The number of nodes in the graph. + :param n_edges_per_node: The maximum number of edges per node. 
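+
+    For example (illustrative), with ``centers = [0, 0, 1, 2, 2, 2]``,
+    ``n_nodes = 3`` and ``n_edges_per_node = 3``, the NEF layout has one row
+    per node: node 0 holds edges 0 and 1 plus one padding slot, node 1 holds
+    edge 2 plus two padding slots, and node 2 holds edges 3, 4 and 5, with
+    ``nef_mask`` marking the real (non-padded) entries.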
+ + :return: A tuple with three tensors (nef_indices, nef_to_edges_neighbor, nef_mask). + In particular: + nef_array = edge_array[nef_indices] + edge_array = nef_array[centers, nef_to_edges_neighbor] + The third output, nef_mask, is a mask that can be used to + filter out the padding values in the NEF array, as different + nodes will have, in general, different number of edges. + """ + + bincount = torch.bincount(centers, minlength=n_nodes) + + arange = torch.arange(n_edges_per_node, device=centers.device) + arange_expanded = arange.view(1, -1).expand(n_nodes, -1) + nef_mask = arange_expanded < bincount.view(-1, 1) + + argsort = torch.argsort(centers, stable=True) + + nef_indices = torch.zeros( + (n_nodes, n_edges_per_node), dtype=torch.long, device=centers.device + ) + nef_indices[nef_mask] = argsort + + nef_to_edges_neighbor = torch.empty_like(centers, dtype=torch.long) + nef_to_edges_neighbor[argsort] = arange_expanded[nef_mask] + + return nef_indices, nef_to_edges_neighbor, nef_mask + + +def get_corresponding_edges(array: torch.Tensor) -> torch.Tensor: + """ + Computes the corresponding edge (i.e., the edge that goes in the + opposite direction) for each edge in the array; this is useful + in the message-passing operation. + + :param array: A 2D tensor of shape (n_edges, 5). For each i -> j + edge, the first column contains the index of the center node i, + the second column contains the index of the neighbor node j, + and the last three columns contain the cell shifts along x, y, and z + directions, respectively. + + :return: A 1D tensor of shape (n_edges,) containing, for each edge, + the index of the corresponding edge (i.e., the edge that goes + in the opposite direction). If the input array is empty, an + empty tensor is returned. + """ + + if array.numel() == 0: + return torch.empty((0,), dtype=array.dtype, device=array.device) + + array = array.to(torch.int64) # avoid overflow + + centers = array[:, 0] + neighbors = array[:, 1] + cell_shifts_x = array[:, 2] + cell_shifts_y = array[:, 3] + cell_shifts_z = array[:, 4] + + # will be useful later + negative_cell_shifts_x = -cell_shifts_x + negative_cell_shifts_y = -cell_shifts_y + negative_cell_shifts_z = -cell_shifts_z + + # create a unique identifier for each edge + # first, we shift the cell_shifts so that the minimum value is 0 + min_cell_shift_x = cell_shifts_x.min() + cell_shifts_x = cell_shifts_x - min_cell_shift_x + negative_cell_shifts_x = negative_cell_shifts_x - min_cell_shift_x + + min_cell_shift_y = cell_shifts_y.min() + cell_shifts_y = cell_shifts_y - min_cell_shift_y + negative_cell_shifts_y = negative_cell_shifts_y - min_cell_shift_y + + min_cell_shift_z = cell_shifts_z.min() + cell_shifts_z = cell_shifts_z - min_cell_shift_z + negative_cell_shifts_z = negative_cell_shifts_z - min_cell_shift_z + + max_centers_neigbors = centers.max() + 1 # same as neighbors.max() + 1 + max_shift_x = cell_shifts_x.max() + 1 + max_shift_y = cell_shifts_y.max() + 1 + max_shift_z = cell_shifts_z.max() + 1 + + size_1 = max_shift_z + size_2 = max_shift_y * size_1 + size_3 = max_shift_x * size_2 + size_4 = max_centers_neigbors * size_3 + + unique_id = ( + centers * size_4 + + neighbors * size_3 + + cell_shifts_x * size_2 + + cell_shifts_y * size_1 + + cell_shifts_z + ) + + # the inverse is the same, but centers and neighbors are swapped + # and we use the negative values of the cell_shifts + unique_id_inverse = ( + neighbors * size_4 + + centers * size_3 + + negative_cell_shifts_x * size_2 + + negative_cell_shifts_y * size_1 + + 
negative_cell_shifts_z + ) + + unique_id_argsort = unique_id.argsort() + unique_id_inverse_argsort = unique_id_inverse.argsort() + + corresponding_edges = torch.empty_like(centers) + corresponding_edges[unique_id_argsort] = unique_id_inverse_argsort + + return corresponding_edges.to(array.dtype) + + +def edge_array_to_nef( + edge_array: torch.Tensor, + nef_indices: torch.Tensor, + mask: Optional[torch.Tensor] = None, + fill_value: float = 0.0, +) -> torch.Tensor: + """Converts an edge array to a NEF array. + + :param edge_array: A tensor where the first dimension is the index of + the edge, i.e. with shape (n_edges, ...). + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param mask: An optional boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. If provided, + the output NEF array will have the values in the positions + where the mask is False set to ``fill_value``. + :param fill_value: The value to use to fill the positions in the + NEF array where the mask is False. Only used if ``mask`` is + provided. + + :return: A tensor with the same information as ``edge_array``, + but in NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + If ``mask`` is provided, the values in the positions where + the mask is False are set to ``fill_value``. + """ + if mask is None: + return edge_array[nef_indices] + else: + return torch.where( + mask.reshape(mask.shape + (1,) * (len(edge_array.shape) - 1)), + edge_array[nef_indices], + fill_value, + ) + + +def nef_array_to_edges( + nef_array: torch.Tensor, centers: torch.Tensor, nef_to_edges_neighbor: torch.Tensor +) -> torch.Tensor: + """Converts a NEF array to an edge array. + + :param nef_array: A tensor where the first two dimensions are the + indices of the NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + :param centers: The indices of the center nodes for each edge. + :param nef_to_edges_neighbor: The indices of the edges for each + neighbor in the NEF layout, as returned by the ``get_nef_indices`` function. + + :return: A tensor with the same information as ``nef_array``, + but in edge layout, i.e. with shape (n_edges, ...). + """ + return nef_array[centers, nef_to_edges_neighbor] + + +def compute_reversed_neighbor_list( + nef_indices: torch.Tensor, + corresponding_edges: torch.Tensor, + nef_mask: torch.Tensor, +) -> torch.Tensor: + """ + Creates a reversed neighborlist, where for each + center atom `i` and its neighbor `j` in the original + neighborlist, the position of atom `i` in the list + of neighbors of atom `j` is returned. + + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param corresponding_edges: The indices of the corresponding edges, + as returned by the ``get_corresponding_edges`` function. + :param nef_mask: A boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. + :return: A tensor of the same shape as ``nef_indices``, + where each entry contains the position of the center + atom in the neighborlist of the corresponding neighbor atom. 
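+
+    A minimal sketch of how this is combined with the other helpers in this
+    module (``edge_index_array`` stands for the (n_edges, 5) array of centers,
+    neighbors and cell shifts described in ``get_corresponding_edges``)::
+
+        nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices(
+            centers, n_nodes, n_edges_per_node
+        )
+        corresponding_edges = get_corresponding_edges(edge_index_array)
+        reversed_nl = compute_reversed_neighbor_list(
+            nef_indices, corresponding_edges, nef_mask
+        )
+        # reversed_nl[i, p] is the position of the reverse (j -> i) edge in
+        # the neighbor list of node j, where j is the p-th neighbor of i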
+ """ + num_atoms, max_num_neighbors = nef_indices.shape + + flat_edge_indices = nef_indices.reshape(-1) + flat_positions = torch.arange(max_num_neighbors, device=nef_indices.device).repeat( + num_atoms + ) + flat_mask = nef_mask.reshape(-1) + + if flat_edge_indices.numel() == 0: + max_edge_index = 0 + else: + max_edge_index = int(flat_edge_indices.max().item()) + 1 + size: List[int] = [max_edge_index] + + edge_index_to_position = torch.full( + size, + 0, + dtype=torch.long, + device=nef_indices.device, + ) + edge_index_to_position[flat_edge_indices[flat_mask]] = flat_positions[flat_mask] + + reverse_edge_idx = corresponding_edges[nef_indices] + reversed_neighbor_list = edge_index_to_position[reverse_edge_idx] + reversed_neighbor_list = reversed_neighbor_list.masked_fill(~nef_mask, 0) + + return reversed_neighbor_list diff --git a/src/metatrain/pet_12/modules/structures.py b/src/metatrain/pet_12/modules/structures.py new file mode 100644 index 0000000000..f7e67f9eaf --- /dev/null +++ b/src/metatrain/pet_12/modules/structures.py @@ -0,0 +1,377 @@ +from typing import List, Optional, Tuple + +import torch +from metatensor.torch import Labels +from metatomic.torch import NeighborListOptions, System + +from .adaptive_cutoff import get_adaptive_cutoffs +from .nef import ( + compute_reversed_neighbor_list, + edge_array_to_nef, + get_corresponding_edges, + get_nef_indices, +) +from .utilities import cutoff_func_bump, cutoff_func_cosine + + +def concatenate_structures( + systems: List[System], + neighbor_list_options: NeighborListOptions, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, +]: + """ + Concatenate a list of systems into a single batch. + + :param systems: List of systems to concatenate. + :param neighbor_list_options: Options for the neighbor list. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the concatenated positions, centers, neighbors, + species, cells, cell shifts, system indices, and sample labels. 
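+
+    A rough sketch of the shapes of the returned tensors (``n_atoms`` and
+    ``n_edges`` are totals over the whole batch)::
+
+        positions       # (n_atoms, 3)
+        centers         # (n_edges,), indices into the concatenated atoms
+        neighbors       # (n_edges,), indices into the concatenated atoms
+        species         # (n_atoms,)
+        cells           # (len(systems), 3, 3)
+        cell_shifts     # (n_edges, 3)
+        system_indices  # (n_atoms,), which system each atom belongs to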
+ """ + + positions: List[torch.Tensor] = [] + centers: List[torch.Tensor] = [] + neighbors: List[torch.Tensor] = [] + species: List[torch.Tensor] = [] + cell_shifts: List[torch.Tensor] = [] + cells: List[torch.Tensor] = [] + system_indices: List[torch.Tensor] = [] + atom_indices: List[torch.Tensor] = [] + node_counter = 0 + + for i, system in enumerate(systems): + assert len(system.known_neighbor_lists()) >= 1, "no neighbor list found" + neighbor_list = system.get_neighbor_list(neighbor_list_options) + nl_values = neighbor_list.samples.values + + centers_values = nl_values[:, 0] + neighbors_values = nl_values[:, 1] + cell_shifts_values = nl_values[:, 2:] + + if selected_atoms is not None: + system_selected_atoms = selected_atoms.values[:, 1][ + selected_atoms.values[:, 0] == i + ] + unique_centers = torch.unique(centers_values) + system_selected_atoms = torch.unique( + torch.cat([system_selected_atoms, unique_centers]) + ) + # calculate the mapping from the ghost atoms to the real atoms + if torch.numel(unique_centers) == 0: + max_center_index = -1 + else: + max_center_index = int(unique_centers.max()) + ghost_to_real_index = torch.full( + [ + max_center_index + 1, + ], + -1, + device=centers_values.device, + dtype=centers_values.dtype, + ) + for j, unique_center_index in enumerate(unique_centers): + ghost_to_real_index[unique_center_index] = j + + centers_values = ghost_to_real_index[centers_values] + neighbors_values = ghost_to_real_index[neighbors_values] + else: + system_selected_atoms = torch.arange( + len(system), device=system.positions.device + ) + + positions.append(system.positions[system_selected_atoms]) + species.append(system.types[system_selected_atoms]) + + centers.append(centers_values + node_counter) + neighbors.append(neighbors_values + node_counter) + cell_shifts.append(cell_shifts_values) + + cells.append(system.cell) + + node_counter += len(system_selected_atoms) + system_indices.append( + torch.full((len(system_selected_atoms),), i, device=system.positions.device) + ) + atom_indices.append( + torch.arange(len(system_selected_atoms), device=system.positions.device) + ) + + positions = torch.cat(positions) + centers = torch.cat(centers) + neighbors = torch.cat(neighbors) + species = torch.cat(species) + cells = torch.stack(cells) + cell_shifts = torch.cat(cell_shifts) + system_indices = torch.cat(system_indices) + atom_indices = torch.cat(atom_indices) + + sample_values = torch.stack( + [system_indices, atom_indices], + dim=1, + ) + sample_labels = Labels( + names=["system", "atom"], + values=sample_values, + assume_unique=True, + ) + + return ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) + + +def systems_to_batch( + systems: List[System], + options: NeighborListOptions, + all_species_list: List[int], + species_to_species_index: torch.Tensor, + cutoff_function: str, + cutoff_width: float, + num_neighbors_adaptive: Optional[float] = None, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, + torch.Tensor, + torch.Tensor, +]: + """ + Converts a list of systems to a batch required for the PET model. + + :param systems: List of systems to convert to a batch. + :param options: Options for the neighbor list. + :param all_species_list: List of all atomic species in the dataset. + :param species_to_species_index: Mapping from atomic species to species indices. 
+ :param cutoff_function: Type of the smoothing function at the cutoff. + :param cutoff_width: Width of the cutoff function for a cutoff mask. + :param num_neighbors_adaptive: Optional maximum number of neighbors per atom. + If provided, the adaptive cutoff scheme will be used for each atom to + approximately select this number of neighbors. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the batch tensors. + The batch consists of the following tensors: + - `element_indices_nodes`: The atomic species of the central atoms + - `element_indices_neighbors`: The atomic species of the neighboring atoms + - `edge_vectors`: The cartesian edge vectors between the central atoms and their + neighbors + - `edge_distances`: The distances between the central atoms and their neighbors + - `padding_mask`: A padding mask indicating which neighbors are real, and which + are padded + - `reverse_neighbor_index`: The reversed neighbor list for each central atom + - `cutoff_factors`: The cutoff function values for each edge + - `system_indices`: The system index for each atom in the batch + - `sample_labels`: Labels indicating the system and atom indices for each atom + + """ + ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) = concatenate_structures(systems, options, selected_atoms) + + # somehow the backward of this operation is very slow at evaluation, + # where there is only one cell, therefore we simplify the calculation + # for that case + if len(cells) == 1: + cell_contributions = cell_shifts.to(cells.dtype) @ cells[0] + else: + cell_contributions = torch.einsum( + "ab, abc -> ac", + cell_shifts.to(cells.dtype), + cells[system_indices[centers]], + ) + edge_vectors = positions[neighbors] - positions[centers] + cell_contributions + edge_distances = torch.norm(edge_vectors, dim=-1) + 1e-15 + + if selected_atoms is not None: + if torch.numel(centers) == 0: + num_nodes = 0 + else: + num_nodes = int(centers.max()) + 1 + else: + num_nodes = len(positions) + + if num_neighbors_adaptive is not None: + with torch.profiler.record_function("PET::get_adaptive_cutoffs"): + # Adaptive cutoff scheme to approximately select `num_neighbors_adaptive` + # neighbors for each atom + atomic_cutoffs = get_adaptive_cutoffs( + centers, + edge_distances, + num_neighbors_adaptive, + num_nodes, + options.cutoff, + cutoff_width=cutoff_width, + ) + # Symmetrize the cutoffs between pairs of atoms (PET needs this symmetry + # due to its corresponding edge indexing ij -> ji) + pair_cutoffs = (atomic_cutoffs[centers] + atomic_cutoffs[neighbors]) / 2.0 + with torch.profiler.record_function("PET::adaptive_cutoff_masking"): + # Apply cutoff mask + cutoff_mask = edge_distances <= pair_cutoffs + + pair_cutoffs = pair_cutoffs[cutoff_mask] + centers = centers[cutoff_mask] + neighbors = neighbors[cutoff_mask] + edge_vectors = edge_vectors[cutoff_mask] + cell_shifts = cell_shifts[cutoff_mask] + edge_distances = edge_distances[cutoff_mask] + else: + pair_cutoffs = options.cutoff * torch.ones( + len(centers), device=positions.device, dtype=positions.dtype + ) + + num_neighbors = torch.bincount(centers, minlength=num_nodes) + max_edges_per_node = int(torch.max(num_neighbors)) + + # uncomment these to print out stats on the adaptive cutoff behavior + # print("adaptive_cutoffs", *pair_cutoffs.tolist()) + # print("num_neighbors", *num_neighbors.tolist()) + + if cutoff_function.lower() == "bump": + # use bump switching 
function for adaptive cutoff + cutoff_factors = cutoff_func_bump(edge_distances, pair_cutoffs, cutoff_width) + elif cutoff_function.lower() == "cosine": + # backward-compatible cosine swithcing for fixed cutoff + cutoff_factors = cutoff_func_cosine(edge_distances, pair_cutoffs, cutoff_width) + else: + raise ValueError( + f"Unknown cutoff function type: {cutoff_function}. " + f"Supported types are 'Cosine' and 'Bump'." + ) + + # Convert to NEF (Node-Edge-Feature) format: + nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices( + centers, num_nodes, max_edges_per_node + ) + + # Element indices + element_indices_nodes = species_to_species_index[species] + element_indices_neighbors = element_indices_nodes[neighbors] + + # Send everything to NEF: + edge_vectors = edge_array_to_nef(edge_vectors, nef_indices) + edge_distances = torch.sqrt(torch.sum(edge_vectors**2, dim=2) + 1e-15) + element_indices_neighbors = edge_array_to_nef( + element_indices_neighbors, nef_indices + ) + cutoff_factors = edge_array_to_nef(cutoff_factors, nef_indices, nef_mask, 0.0) + + corresponding_edges = get_corresponding_edges( + torch.concatenate( + [centers.unsqueeze(-1), neighbors.unsqueeze(-1), cell_shifts], + dim=-1, + ) + ) + + # These are the two arrays we need for message passing with edge reversals, + # if indexing happens in a two-dimensional way: + # edges_ji = edges_ij[reversed_neighbor_list, neighbors_index] + reversed_neighbor_list = compute_reversed_neighbor_list( + nef_indices, corresponding_edges, nef_mask + ) + neighbors_index = edge_array_to_nef(neighbors, nef_indices).to(torch.int64) + + # Here, we compute the array that allows indexing into a flattened + # version of the edge array (where the first two dimensions are merged): + reverse_neighbor_index = ( + neighbors_index * neighbors_index.shape[1] + reversed_neighbor_list + ) + # At this point, we have `reverse_neighbor_index[~nef_mask] = 0`, which however + # creates too many of the same index which slows down backward enormously. + # (See see https://github.com/pytorch/pytorch/issues/41162) + # We therefore replace the padded indices with a sequence of unique indices. + reverse_neighbor_index[~nef_mask] = torch.arange( + int(torch.sum(~nef_mask)), device=reverse_neighbor_index.device + ) + + return ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + nef_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + sample_labels, + centers, + nef_to_edges_neighbor, + ) + + +def get_pair_sample_labels( + systems: List[System], + sample_labels: Labels, + nl_options: NeighborListOptions, + device: torch.device, +) -> Labels: + """ + Builds the pair samples labels for the input ``systems``, based on the pre-computed + neighbor list. These are 'off-site', i.e. not including self-interactions. + + :param systems: List of systems to build the pair sample labels for. + :param sample_labels: The sample labels for per-atom quantities. + :param nl_options: The neighbor list options to use for building the offsite labels. + :param device: The device to put the labels on. + :return: A dictionary with the pair sample labels for the onsite and offsite blocks. 
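+
+    Each row of the returned labels describes one directed pair: for example,
+    a row ``(0, 3, 7, 0, 0, 1)`` refers to the pair between atoms 3 and 7 of
+    system 0, connected through the (0, 0, 1) cell shift (see ``sample_names``
+    in the implementation below).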
+ """ + sample_names = [ + "system", + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ] + + pair_sample_values = [] + for system_idx, system in enumerate(systems): + neighbor_list = system.get_neighbor_list(nl_options) + nl_values = neighbor_list.samples.values + + pair_sample_values.append( + torch.hstack( + [ + torch.full( + (nl_values.shape[0], 1), + system_idx, + dtype=torch.int32, + device=device, + ), + nl_values, + ], + ) + ) + pair_sample_values = torch.vstack(pair_sample_values) + pair_sample_labels = Labels(sample_names, pair_sample_values).to(device=device) + + return pair_sample_labels diff --git a/src/metatrain/pet_12/modules/transformer.py b/src/metatrain/pet_12/modules/transformer.py new file mode 100644 index 0000000000..b3ee773a63 --- /dev/null +++ b/src/metatrain/pet_12/modules/transformer.py @@ -0,0 +1,556 @@ +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from .utilities import DummyModule +from sphericart.torch import SolidHarmonics + + +AVAILABLE_NORMALIZATIONS = ["LayerNorm", "RMSNorm"] +AVAILABLE_TRANSFORMER_TYPES = ["PostLN", "PreLN"] +AVAILABLE_ACTIVATIONS = ["SiLU", "SwiGLU"] + + +class FeedForward(nn.Module): + def __init__(self, d_model: int, dim_feedforward: int, activation: str) -> None: + super().__init__() + + # Check if activation is "swiglu" string + if activation.lower() == "swiglu": + # SwiGLU mode: single projection produces both "value" and "gate" + self.w_in = nn.Linear(d_model, 2 * dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = torch.nn.Identity() + self.is_swiglu = True + else: + # Standard mode: regular activation function + self.w_in = nn.Linear(d_model, dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = getattr(F, activation.lower()) + self.is_swiglu = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_swiglu: + # SwiGLU activation: split into value and gate + v, g = self.w_in(x).chunk(2, dim=-1) + x = v * torch.sigmoid(g) + x = self.w_out(x) + else: + # Standard activation + x = self.w_in(x) + x = self.activation(x) + x = self.w_out(x) + return x + + +class AttentionBlock(nn.Module): + """ + Multi-head attention block. + + :param total_dim: The total dimension of the input and output tensors. + :param num_heads: The number of attention heads. + :param temperature: An additional scaling factor for attention scores. + This is combined with the standard scaling by the square root of + the head dimension. + :param epsilon: A small value to avoid division by zero. + """ + + def __init__( + self, + total_dim: int, + num_heads: int, + temperature: float, + epsilon: float = 1e-15, + ) -> None: + super(AttentionBlock, self).__init__() + + self.input_linear = nn.Linear(total_dim, 3 * total_dim) + self.output_linear = nn.Linear(total_dim, total_dim) + + self.num_heads = num_heads + self.epsilon = epsilon + self.temperature = temperature + if total_dim % num_heads != 0: + raise ValueError("total dimension is not divisible by the number of heads") + self.head_dim = total_dim // num_heads + + def forward( + self, x: torch.Tensor, cutoff_factors: torch.Tensor, use_manual_attention: bool + ) -> torch.Tensor: + """ + Forward pass for the attention block. 
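+
+        The cutoff factors enter the attention as an additive mask:
+        ``log(clamp(cutoff_factors, epsilon))`` is added to the attention
+        logits, so edges whose cutoff factor is (close to) zero receive
+        (close to) zero weight after the softmax.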
+ + :param x: The input tensor, of shape (batch_size, seq_length, total_dim) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: The output tensor, of shape (batch_size, seq_length, total_dim) + """ + initial_shape = x.shape + x = self.input_linear(x) + x = x.reshape( + initial_shape[0], initial_shape[1], 3, self.num_heads, self.head_dim + ) + x = x.permute(2, 0, 3, 1, 4) + + queries, keys, values = x[0], x[1], x[2] + attn_weights = torch.clamp(cutoff_factors[:, None, :, :], self.epsilon) + attn_weights = torch.log(attn_weights) + if use_manual_attention: + x = manual_attention(queries, keys, values, attn_weights, self.temperature) + else: + x = torch.nn.functional.scaled_dot_product_attention( + queries, + keys, + values, + attn_mask=attn_weights, + scale=1.0 / (self.head_dim**0.5 * self.temperature), + ) + x = x.transpose(1, 2).reshape(initial_shape) + x = self.output_linear(x) + return x + + +class TransformerLayer(torch.nn.Module): + """ + Single layer of a Transformer. + + :param d_model: The dimension of the model. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param temperature: An additional scaling factor for attention scores. 
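+
+    With ``transformer_type="PreLN"`` the normalization is applied before the
+    attention and feed-forward blocks and the residual is added afterwards;
+    with ``"PostLN"`` the residual is added first and the normalization is
+    applied to the sum (see ``_forward_pre_ln_impl`` and
+    ``_forward_post_ln_impl``).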
+ """ + + def __init__( + self, + d_model: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + temperature: float = 1.0, + ) -> None: + super(TransformerLayer, self).__init__() + self.attention = AttentionBlock(d_model, n_heads, temperature) + self.transformer_type = transformer_type + self.d_model = d_model + norm_class = getattr(nn, norm) + self.norm_attention = norm_class(d_model) + self.norm_mlp = norm_class(d_model) + self.mlp = FeedForward(d_model, dim_feedforward, activation) + self.expanded_node_features = False + if dim_node_features != d_model: + self.expanded_node_features = True + self.center_contraction = nn.Linear(dim_node_features, d_model) + self.center_expansion = nn.Linear(d_model, dim_node_features) + self.norm_center_features = norm_class(dim_node_features) + self.center_mlp = FeedForward( + dim_node_features, 2 * dim_node_features, activation + ) + else: + self.center_contraction = torch.nn.Identity() + self.center_expansion = torch.nn.Identity() + self.norm_center_features = torch.nn.Identity() + self.center_mlp = torch.nn.Identity() + + def _forward_pre_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + new_tokens = self.attention( + self.norm_attention(tokens), cutoff_factors, use_manual_attention + ) + output_node_embeddings, output_edge_embeddings = torch.split( + new_tokens, [1, new_tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + + output_edge_embeddings = edge_embeddings + output_edge_embeddings + output_edge_embeddings = output_edge_embeddings + self.mlp( + self.norm_mlp(output_edge_embeddings) + ) + + return output_node_embeddings, output_edge_embeddings + + def _forward_post_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + tokens = self.norm_attention( + tokens + self.attention(tokens, cutoff_factors, use_manual_attention) + ) + tokens = self.norm_mlp(tokens + self.mlp(tokens)) + output_node_embeddings, output_edge_embeddings = torch.split( + tokens, [1, tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + return output_node_embeddings, output_edge_embeddings + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward 
pass for a single Transformer layer. + + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + if self.transformer_type == "PostLN": + node_embeddings, edge_embeddings = self._forward_post_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + if self.transformer_type == "PreLN": + node_embeddings, edge_embeddings = self._forward_pre_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class Transformer(torch.nn.Module): + """ + Transformer implementation. + + :param d_model: The dimension of the model. + :param num_layers: The number of transformer layers. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param attention_temperature: The temperature scaling factor for attention + scores. This is combined with the standard scaling by the square root of + the head dimension. + """ + + def __init__( + self, + d_model: int, + num_layers: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + attention_temperature: float = 1.0, + ) -> None: + super(Transformer, self).__init__() + if norm not in AVAILABLE_NORMALIZATIONS: + raise ValueError( + f"Unknown normalization flag: {norm}. " + f"Please choose from: {AVAILABLE_NORMALIZATIONS}" + ) + + if transformer_type not in AVAILABLE_TRANSFORMER_TYPES: + raise ValueError( + f"Unknown transformer flag: {transformer_type}. " + f"Please choose from: {AVAILABLE_TRANSFORMER_TYPES}" + ) + self.transformer_type = transformer_type + + if activation not in AVAILABLE_ACTIVATIONS: + raise ValueError( + f"Unknown activation flag: {activation}. " + f"Please choose from: {AVAILABLE_ACTIVATIONS}" + ) + + self.layers = nn.ModuleList( + [ + TransformerLayer( + d_model=d_model, + n_heads=n_heads, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + temperature=attention_temperature, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the Transformer. 
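+
+        The ``(node_embeddings, edge_embeddings)`` pair is threaded through
+        ``self.layers`` in order; every layer shares the same signature, so
+        the output of one layer is passed directly as the input of the next.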
+ + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + for layer in self.layers: + node_embeddings, edge_embeddings = layer( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class CartesianTransformer(torch.nn.Module): + """ + Cartesian Transformer implementation for handling 3D coordinates. + + :param cutoff: The cutoff distance for neighbor interactions. + :param cutoff_width: The width of the cutoff function. + :param d_model: The dimension of the model. + :param n_head: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param n_layers: The number of transformer layers. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param attention_temperature: The temperature scaling factor for attention scores. + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param n_atomic_species: The number of atomic species. + :param is_first: Whether this is the first transformer in the model. + """ + + def __init__( + self, + cutoff: float, + cutoff_width: float, + d_model: int, + n_head: int, + dim_node_features: int, + dim_feedforward: int, + n_layers: int, + norm: str, + activation: str, + attention_temperature: float, + transformer_type: str, + n_atomic_species: int, + is_first: bool, + ) -> None: + super(CartesianTransformer, self).__init__() + self.is_first = is_first + self.cutoff = cutoff + self.cutoff_width = cutoff_width + self.trans = Transformer( + d_model=d_model, + num_layers=n_layers, + n_heads=n_head, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + attention_temperature=attention_temperature, + ) + + self.l_max = 4 + self.spherical_harmonics = SolidHarmonics(l_max=self.l_max) + self.edge_embedder = nn.Linear((self.l_max + 1) ** 2, d_model) + self.rmsnorm = nn.LayerNorm(d_model) + + if not is_first: + n_merge = 3 + else: + n_merge = 2 + + self.compress = nn.Sequential( + nn.Linear(n_merge * d_model, d_model), + torch.nn.SiLU(), + nn.Linear(d_model, d_model), + ) + + self.neighbor_embedder = DummyModule() # for torchscript + if not is_first: + self.neighbor_embedder = nn.Embedding(n_atomic_species, d_model) + + def forward( + self, + input_node_embeddings: torch.Tensor, + input_messages: torch.Tensor, + element_indices_neighbors: torch.Tensor, + edge_vectors: torch.Tensor, + padding_mask: torch.Tensor, + edge_distances: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the CartesianTransformer. 
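+
+        Edge tokens are built from the solid-harmonics expansion of the edge
+        vectors (up to ``l_max``), concatenated with the incoming messages
+        (and, for all but the first GNN layer, with the neighbor-species
+        embeddings), compressed to ``d_model``, and processed by the inner
+        ``Transformer`` together with a prepended central-atom token.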
+ + :param input_node_embeddings: The input node embeddings, of shape + (n_nodes, d_model) + :param input_messages: The input messages to the transformer, of shape + (n_nodes, max_num_neighbors, d_model) + :param element_indices_neighbors: The atomic species of the neighboring atoms, + of shape (n_nodes, max_num_neighbors) + :param edge_vectors: The cartesian edge vectors between the central atoms and + their neighbors, of shape (n_nodes, max_num_neighbors, 3) + :param padding_mask: A padding mask indicating which neighbors are real, and + which are padded, of shape (n_nodes, max_num_neighbors) + :param edge_distances: The distances between the central atoms and their + neighbors, of shape (n_nodes, max_num_neighbors) + :param cutoff_factors: The cutoff factors for the edges, of shape + (n_nodes, max_num_neighbors) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (n_nodes, d_model) + - The output edge embeddings, of shape (n_nodes, max_num_neighbors, d_model) + """ + node_embeddings = input_node_embeddings + # edge_embeddings = [edge_vectors, edge_distances[:, :, None]] + # edge_embeddings = torch.cat(edge_embeddings, dim=2) + edge_embeddings = self.spherical_harmonics( + edge_vectors.reshape(-1, 3) + ).reshape(edge_vectors.shape[0], edge_vectors.shape[1], -1) + edge_embeddings = self.edge_embedder(edge_embeddings) + edge_embeddings = self.rmsnorm(edge_embeddings) + + if not self.is_first: + neighbor_elements_embeddings = self.neighbor_embedder( + element_indices_neighbors + ) + edge_tokens = torch.cat( + [edge_embeddings, neighbor_elements_embeddings, input_messages], dim=2 + ) + else: + neighbor_elements_embeddings = torch.empty( + 0, device=edge_vectors.device, dtype=edge_vectors.dtype + ) # for torch script + edge_tokens = torch.cat([edge_embeddings, input_messages], dim=2) + + edge_tokens = self.compress(edge_tokens) + # tokens = torch.cat([node_elements_embedding[:, None, :], tokens], dim=1) + + padding_mask_with_central_token = torch.ones( + padding_mask.shape[0], dtype=torch.bool, device=padding_mask.device + ) + total_padding_mask = torch.cat( + [padding_mask_with_central_token[:, None], padding_mask], dim=1 + ) + + cutoff_subfactors = torch.ones( + padding_mask.shape[0], + dtype=cutoff_factors.dtype, + device=padding_mask.device, + ) + cutoff_factors = torch.cat([cutoff_subfactors[:, None], cutoff_factors], dim=1) + cutoff_factors[~total_padding_mask] = 0.0 + + cutoff_factors = cutoff_factors[:, None, :] + cutoff_factors = cutoff_factors.repeat(1, cutoff_factors.shape[2], 1) + + initial_num_tokens = edge_vectors.shape[1] + max_num_tokens = input_messages.shape[1] + + output_node_embeddings, output_edge_embeddings = self.trans( + node_embeddings[:, None, :], + edge_tokens[:, :max_num_tokens, :], + cutoff_factors=cutoff_factors[ + :, : (max_num_tokens + 1), : (max_num_tokens + 1) + ], + use_manual_attention=use_manual_attention, + ) + if max_num_tokens < initial_num_tokens: + padding = torch.zeros( + output_edge_embeddings.shape[0], + initial_num_tokens - max_num_tokens, + output_edge_embeddings.shape[2], + device=output_edge_embeddings.device, + ) + output_edge_embeddings = torch.cat([output_edge_embeddings, padding], dim=1) + output_node_embeddings = output_node_embeddings.squeeze(1) + return output_node_embeddings, 
output_edge_embeddings + + +def manual_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: torch.Tensor, + temperature: float, +) -> torch.Tensor: + """ + Implements the attention operation manually, using basic PyTorch operations. + We need it because the built-in PyTorch attention does not support double backward, + which is needed when training with conservative forces. + + :param q: The queries + :param k: The keys + :param v: The values + :param attn_mask: The attention mask + :param temperature: An additional scaling factor for attention scores. + :return: The result of the attention operation + """ + attention_weights = ( + torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5 * temperature) + ) + attn_mask + attention_weights = attention_weights.softmax(dim=-1) + attention_output = torch.matmul(attention_weights, v) + return attention_output \ No newline at end of file diff --git a/src/metatrain/pet_12/modules/utilities.py b/src/metatrain/pet_12/modules/utilities.py new file mode 100644 index 0000000000..5795c5aed8 --- /dev/null +++ b/src/metatrain/pet_12/modules/utilities.py @@ -0,0 +1,63 @@ +import torch + + +def cutoff_func_bump( + values: torch.Tensor, cutoff: torch.Tensor, width: float, eps: float = 1e-6 +) -> torch.Tensor: + """ + Bump cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :param eps: Avoid computing at values too close to the edges. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + f[mask_active] = 0.5 * ( + 1 + torch.tanh(1 / torch.tan(torch.pi * scaled_values[mask_active])) + ) + f[mask_smaller] = 1.0 + + return f + + +def cutoff_func_cosine( + values: torch.Tensor, cutoff: torch.Tensor, width: float +) -> torch.Tensor: + """ + Cosine cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + + f[mask_active] = 0.5 + 0.5 * torch.cos(torch.pi * scaled_values[mask_active]) + f[mask_smaller] = 1.0 + return f + + +class DummyModule(torch.nn.Module): + """Dummy torch module to make torchscript happy. 
+ This model should never be run""" + + def __init__(self) -> None: + super(DummyModule, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise RuntimeError("This model should never be run") diff --git a/src/metatrain/pet_12/trainer.py b/src/metatrain/pet_12/trainer.py new file mode 100644 index 0000000000..49e82212bc --- /dev/null +++ b/src/metatrain/pet_12/trainer.py @@ -0,0 +1,639 @@ +import copy +import logging +import math +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Union, cast + +import torch +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader, DistributedSampler + +from metatrain.utils.abc import ModelInterface, TrainerInterface +from metatrain.utils.additive import get_remove_additive_transform +from metatrain.utils.augmentation import RotationalAugmenter +from metatrain.utils.data import ( + CollateFn, + CombinedDataLoader, + Dataset, + get_num_workers, + unpack_batch, + validate_num_workers, +) +from metatrain.utils.distributed.batch_utils import should_skip_batch +from metatrain.utils.distributed.distributed_data_parallel import ( + DistributedDataParallel, +) +from metatrain.utils.distributed.slurm import DistributedEnvironment +from metatrain.utils.evaluate_model import evaluate_model +from metatrain.utils.io import check_file_extension +from metatrain.utils.logging import ROOT_LOGGER, MetricLogger +from metatrain.utils.loss import LossAggregator, LossSpecification +from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists_transform, +) +from metatrain.utils.per_atom import average_by_num_atoms +from metatrain.utils.scaler import get_remove_scale_transform +from metatrain.utils.transfer import batch_to + +from . import checkpoints +from .documentation import TrainerHypers +from .model import PET +from .modules.finetuning import apply_finetuning_strategy + + +def get_scheduler( + optimizer: torch.optim.Optimizer, + train_hypers: TrainerHypers, + steps_per_epoch: int, +) -> LambdaLR: + """ + Get a CosineAnnealing learning-rate scheduler with warmup + + :param optimizer: The optimizer for which to create the scheduler. + :param train_hypers: The training hyperparameters. + :param steps_per_epoch: The number of steps per epoch. + :return: The learning rate scheduler. 
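A rough worked illustration of this schedule (a sketch, not part of the diff, assuming num_epochs=10, steps_per_epoch=100 and warmup_fraction=0.01, i.e. total_steps=1000 and warmup_steps=10):

    # step    5 -> 5 / 10                              = 0.5  (linear warmup)
    # step   10 -> progress 0.0, 0.5 * (1 + cos(0))    = 1.0  (full learning rate)
    # step  505 -> progress 0.5, 0.5 * (1 + cos(pi/2)) = 0.5  (halfway through the decay)
    # step 1000 -> progress 1.0, 0.5 * (1 + cos(pi))   = 0.0  (min_lr_ratio is hardcoded to 0)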
+ """ + total_steps = train_hypers["num_epochs"] * steps_per_epoch + warmup_steps = int(train_hypers["warmup_fraction"] * total_steps) + min_lr_ratio = 0.0 # hardcoded for now, could be made configurable in the future + + def lr_lambda(current_step: int) -> float: + if current_step < warmup_steps: + # Linear warmup + return float(current_step) / float(max(1, warmup_steps)) + else: + # Cosine decay + progress = (current_step - warmup_steps) / float( + max(1, total_steps - warmup_steps) + ) + cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) + return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay + + scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) + return scheduler + + +class Trainer(TrainerInterface[TrainerHypers]): + __checkpoint_version__ = 12 + + def __init__(self, hypers: TrainerHypers) -> None: + super().__init__(hypers) + + self.optimizer_state_dict: Optional[Dict[str, Any]] = None + self.scheduler_state_dict: Optional[Dict[str, Any]] = None + self.epoch: Optional[int] = None + self.best_epoch: Optional[int] = None + self.best_metric: Optional[float] = None + self.best_model_state_dict: Optional[Dict[str, Any]] = None + self.best_optimizer_state_dict: Optional[Dict[str, Any]] = None + + def train( + self, + model: PET, + dtype: torch.dtype, + devices: List[torch.device], + train_datasets: List[Union[Dataset, torch.utils.data.Subset]], + val_datasets: List[Union[Dataset, torch.utils.data.Subset]], + checkpoint_dir: str, + ) -> None: + assert dtype in PET.__supported_dtypes__ + + is_distributed = self.hypers["distributed"] + is_finetune = self.hypers["finetune"]["read_from"] is not None + + if is_distributed: + if len(devices) > 1: + raise ValueError( + "Requested distributed training with the `multi-gpu` device. " + " If you want to run distributed training with PET, please " + "set `device` to cuda." 
+                )
+            # the calculation of the device number works both when GPUs on different
+            # processes are not visible to each other and when they are
+            distr_env = DistributedEnvironment(self.hypers["distributed_port"])
+            device_number = distr_env.local_rank % torch.cuda.device_count()
+            device = torch.device("cuda", device_number)
+            torch.distributed.init_process_group(backend="nccl", device_id=device)
+            world_size = torch.distributed.get_world_size()
+            rank = torch.distributed.get_rank()
+        else:
+            rank = 0
+            device = devices[0]
+            # only one device, as we don't support non-distributed multi-gpu for now
+
+        if is_distributed:
+            logging.info(f"Training on {world_size} devices with dtype {dtype}")
+        else:
+            logging.info(f"Training on device {device} with dtype {dtype}")
+
+        # Apply fine-tuning strategy if provided
+        if is_finetune:
+            assert self.hypers["finetune"]["read_from"] is not None  # for mypy
+            model = apply_finetuning_strategy(model, self.hypers["finetune"])
+            method = self.hypers["finetune"]["method"]
+            num_params = sum(p.numel() for p in model.parameters())
+            num_trainable_params = sum(
+                p.numel() for p in model.parameters() if p.requires_grad
+            )
+
+            logging.info(f"Applied finetuning strategy: {method}")
+            logging.info(
+                f"Number of trainable parameters: {num_trainable_params} "
+                f"[{num_trainable_params / num_params:.2%}]"
+            )
+            inherit_heads = self.hypers["finetune"]["inherit_heads"]
+            if inherit_heads:
+                logging.info(
+                    "Inheriting initial weights for heads and last layers for targets: "
+                    f"from {list(inherit_heads.values())} to "
+                    f"{list(inherit_heads.keys())}"
+                )
+
+        # Move the model to the device and dtype:
+        model.to(device=device, dtype=dtype)
+        # The additive models of PET are always in float64 (to avoid numerical errors in
+        # the composition weights, which can be very large).
+ for additive_model in model.additive_models: + additive_model.to(dtype=torch.float64) + model.scaler.to(dtype=torch.float64) + + logging.info("Calculating composition weights") + model.additive_models[0].train_model( # this is the composition model + train_datasets, + model.additive_models[1:], + self.hypers["batch_size"], + is_distributed, + self.hypers["atomic_baseline"], + ) + + if self.hypers["scale_targets"]: + logging.info("Calculating scaling weights") + model.scaler.train_model( + train_datasets, + model.additive_models, + self.hypers["batch_size"], + is_distributed, + self.hypers["fixed_scaling_weights"], + ) + + logging.info("Setting up data loaders") + + if is_distributed: + train_samplers = [ + DistributedSampler( + train_dataset, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + ) + for train_dataset in train_datasets + ] + val_samplers = [ + DistributedSampler( + val_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + for val_dataset in val_datasets + ] + else: + train_samplers = [None] * len(train_datasets) + val_samplers = [None] * len(val_datasets) + + # Extract additive models and scaler and move them to CPU/float64 so they + # can be used in the collate function + model.additive_models[0].weights_to(device="cpu", dtype=torch.float64) + additive_models = copy.deepcopy( + model.additive_models.to(dtype=torch.float64, device="cpu") + ) + model.additive_models.to(device) + model.additive_models[0].weights_to(device=device, dtype=torch.float64) + model.scaler.scales_to(device="cpu", dtype=torch.float64) + scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu")) + model.scaler.to(device) + model.scaler.scales_to(device=device, dtype=torch.float64) + + # Create collate functions: + dataset_info = model.dataset_info + train_targets = dataset_info.targets + extra_data_info = dataset_info.extra_data + rotational_augmenter = RotationalAugmenter( + target_info_dict=train_targets, extra_data_info_dict=extra_data_info + ) + requested_neighbor_lists = get_requested_neighbor_lists(model) + collate_fn_train = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ + rotational_augmenter.apply_random_augmentations, + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + collate_fn_val = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ # no augmentation for validation + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + + # Create dataloader for the training datasets: + if self.hypers["num_workers"] is None: + num_workers = get_num_workers() + logging.info( + "Number of workers for data-loading not provided and chosen " + f"automatically. Using {num_workers} workers." + ) + else: + num_workers = self.hypers["num_workers"] + validate_num_workers(num_workers) + + train_dataloaders = [] + for train_dataset, train_sampler in zip( + train_datasets, train_samplers, strict=True + ): + if len(train_dataset) < self.hypers["batch_size"]: + raise ValueError( + f"A training dataset has fewer samples " + f"({len(train_dataset)}) than the batch size " + f"({self.hypers['batch_size']}). " + "Please reduce the batch size." 
+ ) + train_dataloaders.append( + DataLoader( + dataset=train_dataset, + batch_size=self.hypers["batch_size"], + sampler=train_sampler, + shuffle=( + # the sampler takes care of this (if present) + train_sampler is None + ), + drop_last=( + # the sampler takes care of this (if present) + train_sampler is None + ), + collate_fn=collate_fn_train, + num_workers=num_workers, + ) + ) + train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True) + + # Create dataloader for the validation datasets: + val_dataloaders = [] + for val_dataset, val_sampler in zip(val_datasets, val_samplers, strict=True): + val_dataloaders.append( + DataLoader( + dataset=val_dataset, + batch_size=self.hypers["batch_size"], + sampler=val_sampler, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_val, + num_workers=num_workers, + ) + ) + val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False) + + if is_distributed: + model = DistributedDataParallel(model, device_ids=[device]) + + outputs_list = [] + for target_name, target_info in train_targets.items(): + outputs_list.append(target_name) + for gradient_name in target_info.gradients: + outputs_list.append(f"{target_name}_{gradient_name}_gradients") + + # Create a loss function: + loss_hypers = cast(Dict[str, LossSpecification], self.hypers["loss"]) # mypy + loss_fn = LossAggregator(targets=train_targets, config=loss_hypers) + logging.info("Using the following loss functions:") + for name, info in loss_fn.metadata.items(): + logging.info(f"{name}:") + main = {k: v for k, v in info.items() if k != "gradients"} + logging.info(main) + if "gradients" not in info or len(info["gradients"]) == 0: + continue + logging.info("With gradients:") + for grad, ginfo in info["gradients"].items(): + logging.info(f"\t{name}::{grad}: {ginfo}") + + if self.hypers["weight_decay"] is not None: + optimizer = torch.optim.AdamW( + model.parameters(), + lr=self.hypers["learning_rate"], + weight_decay=self.hypers["weight_decay"], + ) + else: + optimizer = torch.optim.Adam( + model.parameters(), lr=self.hypers["learning_rate"] + ) + + if self.optimizer_state_dict is not None and not is_finetune: + # try to load the optimizer state dict, but this is only possible + # if there are no new targets in the model (new parameters) + if not (model.module if is_distributed else model).has_new_targets: + optimizer.load_state_dict(self.optimizer_state_dict) + + # Create a learning rate scheduler + lr_scheduler = get_scheduler(optimizer, self.hypers, len(train_dataloader)) + + if self.scheduler_state_dict is not None and not is_finetune: + # same as the optimizer, try to load the scheduler state dict + if not (model.module if is_distributed else model).has_new_targets: + lr_scheduler.load_state_dict(self.scheduler_state_dict) + + per_structure_targets = self.hypers["per_structure_targets"] + + # Log the initial learning rate: + logging.info(f"Base learning rate: {self.hypers['learning_rate']}") + + start_epoch = -1 if self.epoch is None else self.epoch + epoch = start_epoch + + # Save the untrained model checkpoint: + if self.best_metric is None: + self.best_metric = float("inf") + logging.info("Saving untrained model checkpoint") + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = 
lr_scheduler.state_dict() + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_untrained.ckpt", + ) + + # Train the model: + logging.info("Starting training") + start_epoch = start_epoch + 1 + + for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]): + if is_distributed: + for train_sampler in train_samplers: + train_sampler.set_epoch(epoch) + train_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + val_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + if self.hypers["log_mae"]: + train_mae_calculator = MAEAccumulator( + self.hypers["log_separate_blocks"] + ) + val_mae_calculator = MAEAccumulator(self.hypers["log_separate_blocks"]) + + train_loss = 0.0 + for batch in train_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + optimizer.zero_grad() + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=True, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + train_loss_batch = loss_fn(predictions, targets, extra_data) + + if is_distributed: + # make sure all parameters contribute to the gradient calculation + # to make torch DDP happy + for param in model.parameters(): + train_loss_batch += 0.0 * param.sum() + + train_loss_batch.backward() + torch.nn.utils.clip_grad_norm_( + model.parameters(), self.hypers["grad_clip_norm"] + ) + optimizer.step() + lr_scheduler.step() + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(train_loss_batch) + train_loss += train_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + train_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + train_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_train_info = train_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_train_info.update( + train_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + val_loss = 0.0 + for batch in val_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=False, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + val_loss_batch = loss_fn(predictions, targets, extra_data) + + if 
is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(val_loss_batch) + val_loss += val_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + val_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + val_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_val_info = val_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_val_info.update( + val_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + # Now we log the information: + finalized_train_info = { + "loss": train_loss, + **finalized_train_info, + } + finalized_val_info = { + "loss": val_loss, + **finalized_val_info, + } + + if epoch == start_epoch: + metric_logger = MetricLogger( + log_obj=ROOT_LOGGER, + dataset_info=( + model.module if is_distributed else model + ).dataset_info, + initial_metrics=[finalized_train_info, finalized_val_info], + names=["training", "validation"], + ) + if epoch % self.hypers["log_interval"] == 0: + metric_logger.log( + metrics=[finalized_train_info, finalized_val_info], + epoch=epoch, + rank=rank, + learning_rate=optimizer.param_groups[0]["lr"], + ) + + val_metric = get_selected_metric( + finalized_val_info, self.hypers["best_model_metric"] + ) + if val_metric < self.best_metric: + self.best_metric = val_metric + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + if epoch % self.hypers["checkpoint_interval"] == 0: + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + self.epoch = epoch + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_{epoch}.ckpt", + ) + + # prepare for the checkpoint that will be saved outside the function + self.epoch = epoch + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + + if is_distributed: + torch.distributed.destroy_process_group() + + def save_checkpoint(self, model: ModelInterface, path: Union[str, Path]) -> None: + checkpoint = model.get_checkpoint() + if self.best_model_state_dict is not None: + self.best_model_state_dict["finetune_config"] = model.finetune_config + checkpoint.update( + { + "trainer_ckpt_version": self.__checkpoint_version__, + "train_hypers": self.hypers, + "epoch": self.epoch, + "optimizer_state_dict": self.optimizer_state_dict, + "scheduler_state_dict": self.scheduler_state_dict, + "best_epoch": self.best_epoch, + "best_metric": self.best_metric, + "best_model_state_dict": self.best_model_state_dict, + "best_optimizer_state_dict": self.best_optimizer_state_dict, + } + ) + torch.save( + checkpoint, + check_file_extension(path, ".ckpt"), + ) + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + hypers: TrainerHypers, + context: Literal["restart", "finetune"], + ) -> "Trainer": + trainer = cls(hypers) + trainer.optimizer_state_dict = 
checkpoint["optimizer_state_dict"] + trainer.scheduler_state_dict = checkpoint["scheduler_state_dict"] + trainer.epoch = checkpoint["epoch"] + trainer.best_epoch = checkpoint["best_epoch"] + trainer.best_metric = checkpoint["best_metric"] + trainer.best_model_state_dict = checkpoint["best_model_state_dict"] + trainer.best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"] + + return trainer + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["trainer_ckpt_version"] == v: + update = getattr(checkpoints, f"trainer_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["trainer_ckpt_version"] = v + 1 + + if checkpoint["trainer_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using " + f"trainer version {checkpoint['trainer_ckpt_version']}, while the " + f"current trainer version is {cls.__checkpoint_version__}." + ) + return checkpoint diff --git a/src/metatrain/pet_13/__init__.py b/src/metatrain/pet_13/__init__.py new file mode 100644 index 0000000000..01db7d8aa9 --- /dev/null +++ b/src/metatrain/pet_13/__init__.py @@ -0,0 +1,21 @@ +from .model import PET +from .trainer import Trainer + + +__model__ = PET +__trainer__ = Trainer +__capabilities__ = { + "supported_devices": __model__.__supported_devices__, + "supported_dtypes": __model__.__supported_dtypes__, +} + +__authors__ = [ + ("Sergey Pozdnyakov ", "@spozdn"), + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] + +__maintainers__ = [ + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] diff --git a/src/metatrain/pet_13/checkpoints.py b/src/metatrain/pet_13/checkpoints.py new file mode 100644 index 0000000000..020b1e6590 --- /dev/null +++ b/src/metatrain/pet_13/checkpoints.py @@ -0,0 +1,422 @@ +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + + +########################### +# MODEL ################### +########################### + + +def model_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "additive_models.0.model.type_to_index" not in state_dict: + state_dict["additive_models.0.model.type_to_index"] = state_dict.pop( + "additive_models.0.type_to_index" + ) + + +def model_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "train_hypers" in checkpoint: + finetune_config = checkpoint["train_hypers"].get("finetune", {}) + else: + finetune_config = {} + state_dict["finetune_config"] = finetune_config + + +def model_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["epoch"] = checkpoint.get("epoch") + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + if checkpoint["best_model_state_dict"] is not None: + checkpoint["best_model_state_dict"] = checkpoint.get("best_model_state_dict") + + +def model_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 checkpoint to v5. + + :param checkpoint: The checkpoint to update. 
+ """ + checkpoint["model_data"]["dataset_info"]._atomic_types = list( + checkpoint["model_data"]["dataset_info"]._atomic_types + ) + + +def model_update_v5_v6(checkpoint: dict) -> None: + """ + Update a v5 checkpoint to v6. + + :param checkpoint: The checkpoint to update. + """ + if not checkpoint["best_model_state_dict"]: + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + +def model_update_v6_v7(checkpoint: dict) -> None: + """ + Update model checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to be updated. + """ + # this update consists in changes in the scaler + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if ( + "scaler.scales" not in state_dict + and "scaler.dummy_buffer" in state_dict + and "scaler.model.type_to_index" in state_dict + ): + continue # already updated + old_scales_tensor = state_dict.pop("scaler.scales") + old_output_name_to_output_index = {} + for target_index, target_name in enumerate( + checkpoint["model_data"]["dataset_info"].targets.keys() + ): + old_output_name_to_output_index[target_name] = target_index + state_dict["scaler.dummy_buffer"] = torch.tensor( + [0.0], dtype=old_scales_tensor.dtype + ) + state_dict["scaler.model.type_to_index"] = state_dict[ + "additive_models.0.model.type_to_index" + ] + for target_name, target_info in checkpoint["model_data"][ + "dataset_info" + ].targets.items(): + layout = target_info.layout + if layout.sample_names == ["system"]: + samples = Labels(["atomic_type"], torch.tensor([[-1]])) + + elif layout.sample_names == ["system", "atom"]: + samples = Labels( + ["atomic_type"], + torch.arange( + len(checkpoint["model_data"]["dataset_info"].atomic_types) + ).reshape(-1, 1), + ) + else: + raise ValueError( # will never happen + "Unknown sample kind. Please contact the developers." + ) + scales_tensormap = TensorMap( + keys=layout.keys, + blocks=[ + TensorBlock( + values=torch.full( # important when scale_targets=False + (len(samples), len(block.properties)), + old_scales_tensor[ + old_output_name_to_output_index[target_name] + ], + dtype=torch.float64, + ), + samples=samples, + components=[], + properties=block.properties, + ) + for block in layout.blocks() + ], + ) + state_dict[f"scaler.{target_name}_scaler_buffer"] = mts.save_buffer( + mts.make_contiguous(scales_tensormap) + ) + + +def model_update_v7_v8(checkpoint: dict) -> None: + """ + Update a v7 checkpoint to v8. + + :param checkpoint: The checkpoint to update. + """ + if "node_embedding.weight" in checkpoint["model_state_dict"]: + ############################################## + # **Updating the large-scale PET checkpoints** + ############################################## + + checkpoint["model_data"]["model_hypers"]["normalization"] = "RMSNorm" + checkpoint["model_data"]["model_hypers"]["activation"] = "SwiGLU" + checkpoint["model_data"]["model_hypers"]["d_node"] = ( + 4 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PreLN" + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "feedforward" + checkpoint["model_data"]["model_hypers"]["d_feedforward"] = ( + 2 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + if (state_dict := checkpoint.get("model_state_dict")) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + if "embedding." 
in k and "node" not in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if "node_embedding." in k: + k = k.replace("node_embedding.", "node_embedders.0.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + if "combination_rmsnorms" in k: + k = k.replace("combination_rmsnorms", "combination_norms") + new_state_dict[k] = v + checkpoint["model_state_dict"] = new_state_dict + + # for the large-scale checkpoints, the model for evaluation is always + # taken to be the last + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + else: + ############################################### + # **Updating the standard old PET checkpoints** + ############################################### + + # Adding the option for choosing the normalization type + if "normalization" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["normalization"] = "LayerNorm" + # Adding the option for choosing the activation function + if "activation" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["activation"] = "SiLU" + # Setting the node features dimension to be the same as d_pet if not specified + if "d_node" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["d_node"] = checkpoint[ + "model_data" + ]["model_hypers"]["d_pet"] + # Setting the default transformer type to PostLN if not specified + if "transformer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PostLN" + if "featurizer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "residual" + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + # Replacing the nn.Sequential MLP with a custom FeedForward module + if ".mlp.0" in k: + k = k.replace(".mlp.0", ".mlp.w_in") + if ".mlp.3" in k: + k = k.replace(".mlp.3", ".mlp.w_out") + # Moving the node embedder to a top-level PET model attribute + if "embedding." in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".node_embedder." in k: + key_content = k.split(".") + k = ".".join( + ["node_embedders", key_content[1], key_content[-1]] + ) + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + new_state_dict[k] = v + checkpoint[key] = new_state_dict + + +def model_update_v8_v9(checkpoint: dict) -> None: + """ + Update a v8 checkpoint to v9. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "finetune_config" in state_dict: + if "inherit_heads" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["inherit_heads"] = {} + if "method" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["method"] = "full" + + +def model_update_v9_v10(checkpoint: dict) -> None: + """ + Update a v9 checkpoint to v10. + + :param checkpoint: The checkpoint to update. 
+ """ + # Adding the num_neighbors_adaptive hyperparameter if not present + if "num_neighbors_adaptive" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["num_neighbors_adaptive"] = None + if "cutoff_function" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["cutoff_function"] = "Cosine" + + +def model_update_v10_v11(checkpoint: dict) -> None: + """ + Update a v10 checkpoint to v11. + + :param checkpoint: The checkpoint to update. + """ + # Adding the attention_temperature hyperparameter if not present + if "attention_temperature" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["attention_temperature"] = 1.0 + + +########################### +# TRAINER ################# +########################### + + +def trainer_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 Trainer checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["scheduler_factor"] = checkpoint["train_hypers"].get( + "scheduler_factor", 0.5 + ) + + +def trainer_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 Trainer checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + +def trainer_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 Trainer checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + new_loss_hypers[target_name] = { + "type": old_loss_hypers["type"], + "weight": old_loss_hypers["weights"].get(target_name, 1.0), + "reduction": old_loss_hypers["reduction"], + "sliding_factor": old_loss_hypers.get("sliding_factor", None), + } + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 Trainer checkpoint to v5. + + :param checkpoint: The checkpoint to update. + """ + raise ValueError( + "In order to use this checkpoint, you need metatrain 2025.10 or earlier. " + "You can install it with `pip install metatrain==2025.10`." + ) + + +def trainer_update_v5_v6(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 5 to version 6. + + :param checkpoint: The checkpoint to update. + """ + # num_workers=0 means that the main process will do the data loading, which is + # equivalent to not setting it (this was the behavior before v6) + checkpoint["train_hypers"]["num_workers"] = 0 + + +def trainer_update_v6_v7(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["fixed_scaling_weights"] = {} + + +def trainer_update_v7_v8(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. 
+    """
+    # remove all `sliding_factor` entries from the loss hypers
+    old_loss_hypers = checkpoint["train_hypers"]["loss"].copy()
+    dataset_info = checkpoint["model_data"]["dataset_info"]
+    new_loss_hypers = {}
+
+    for target_name in dataset_info.targets.keys():
+        # retain everything except sliding_factor for each target
+        new_loss_hypers[target_name] = {
+            k: v
+            for k, v in old_loss_hypers[target_name].items()
+            if k != "sliding_factor"
+        }
+
+    checkpoint["train_hypers"]["loss"] = new_loss_hypers
+
+
+def trainer_update_v8_v9(checkpoint: dict) -> None:
+    """
+    Update trainer checkpoint from version 8 to version 9.
+
+    :param checkpoint: The checkpoint to update.
+    """
+    # Adding the empty finetune config if not present
+    if "finetune" not in checkpoint["train_hypers"]:
+        checkpoint["train_hypers"]["finetune"] = {
+            "read_from": None,
+            "method": "full",
+            "config": {},
+            "inherit_heads": {},
+        }
+
+
+def trainer_update_v9_v10(checkpoint: dict) -> None:
+    """
+    Update trainer checkpoint from version 9 to version 10.
+
+    :param checkpoint: The checkpoint to update.
+    """
+    # Adding the ``remove_composition_contribution`` hyper, enabled by default
+    checkpoint["train_hypers"]["remove_composition_contribution"] = True
+
+
+def trainer_update_v10_v11(checkpoint: dict) -> None:
+    """
+    Update trainer checkpoint from version 10 to version 11.
+
+    :param checkpoint: The checkpoint to update.
+    """
+    # - Remove the ``remove_composition_contribution`` hyper.
+    # - Rename ``fixed_composition_weights`` to ``atomic_baseline``.
+    # - If ``remove_composition_contribution`` is False, set all atomic baselines
+    #   to 0.0 for all targets.
+    use_atomic_baseline = checkpoint["train_hypers"].pop(
+        "remove_composition_contribution"
+    )
+    atomic_baseline = checkpoint["train_hypers"].pop("fixed_composition_weights")
+
+    if not use_atomic_baseline:
+        # Just set all atomic baselines to 0.0 for all targets
+        dataset_info = checkpoint["model_data"]["dataset_info"]
+        atomic_baseline = {target_name: 0.0 for target_name in dataset_info.targets}
+
+    checkpoint["train_hypers"]["atomic_baseline"] = atomic_baseline
+
+
+def trainer_update_v11_v12(checkpoint: dict) -> None:
+    """
+    Update trainer checkpoint from version 11 to version 12.
+
+    :param checkpoint: The checkpoint to update.
+    """
+    checkpoint["train_hypers"]["batch_atom_bounds"] = [None, None]
diff --git a/src/metatrain/pet_13/documentation.py b/src/metatrain/pet_13/documentation.py
new file mode 100644
index 0000000000..a0d2719f32
--- /dev/null
+++ b/src/metatrain/pet_13/documentation.py
@@ -0,0 +1,251 @@
+"""
+PET
+===
+
+PET is a cleaner, more user-friendly reimplementation of the original
+PET model :footcite:p:`pozdnyakov_smooth_2023`. It is designed for better
+modularity and maintainability, while preserving compatibility with the original
+PET implementation in ``metatrain``. It also adds new features such as long-range
+features, a better fine-tuning implementation, the possibility to train on
+arbitrary targets, and faster inference due to ``fast attention``.
+
+{{SECTION_INSTALLATION}}
+
+{{SECTION_DEFAULT_HYPERS}}
+
+Tuning hyperparameters
+----------------------
+
+The default hyperparameters above will work well in most cases, but they
+may not be optimal for your specific dataset. There is a good number of
+parameters to tune, both for the
+:ref:`model ` and the
+:ref:`trainer `. Since seeing them
+for the first time might be overwhelming, here we provide a **list of the
+parameters that are in general the most important** (in decreasing order
+of importance):
+
+.. container:: mtt-hypers-remove-classname
+
+   .. autoattribute:: {{model_hypers_path}}.cutoff
+      :no-index:
+
+   .. autoattribute:: {{model_hypers_path}}.num_neighbors_adaptive
+      :no-index:
+
+   .. autoattribute:: {{trainer_hypers_path}}.learning_rate
+      :no-index:
+
+   .. autoattribute:: {{trainer_hypers_path}}.batch_size
+      :no-index:
+
+   .. autoattribute:: {{model_hypers_path}}.d_pet
+      :no-index:
+
+   .. autoattribute:: {{model_hypers_path}}.d_node
+      :no-index:
+
+   .. autoattribute:: {{model_hypers_path}}.num_gnn_layers
+      :no-index:
+
+   .. autoattribute:: {{model_hypers_path}}.num_attention_layers
+      :no-index:
+
+   .. autoattribute:: {{trainer_hypers_path}}.loss
+      :no-index:
+
+   .. autoattribute:: {{model_hypers_path}}.long_range
+      :no-index:
+"""
+
+from typing import Literal, Optional
+
+from typing_extensions import TypedDict
+
+from metatrain.pet_13.modules.finetuning import FinetuneHypers, NoFinetuneHypers
+from metatrain.utils.additive import FixedCompositionWeights
+from metatrain.utils.hypers import init_with_defaults
+from metatrain.utils.long_range import LongRangeHypers
+from metatrain.utils.loss import LossSpecification
+from metatrain.utils.scaler import FixedScalerWeights
+
+
+class ModelHypers(TypedDict):
+    """Hyperparameters for the PET model."""
+
+    cutoff: float = 4.5
+    """Cutoff radius for neighbor search.
+
+    This should be set to a value after which most of the interactions
+    between atoms are expected to be negligible. A lower cutoff will lead
+    to faster models.
+    """
+    num_neighbors_adaptive: Optional[int] = None
+    """Target number of neighbors for the adaptive cutoff scheme.
+
+    This parameter activates the adaptive cutoff functionality.
+    Each atomic environment has a different cutoff, chosen
+    such that the number of neighbors is approximately equal to this
+    value. This can be useful to have a more uniform number of neighbors
+    per atom, especially in sparse systems. Setting it to None disables
+    this feature and uses all neighbors within the fixed cutoff radius.
+    """
+    cutoff_function: Literal["Cosine", "Bump"] = "Bump"
+    """Type of the smoothing function at the cutoff"""
+    cutoff_width: float = 0.5
+    """Width of the smoothing function at the cutoff"""
+    d_pet: int = 128
+    """Dimension of the edge features.
+
+    This hyperparameter controls the width of the neural network. In general,
+    increasing it might lead to better accuracy, especially on larger datasets, at the
+    cost of increased training and evaluation time.
+    """
+    d_head: int = 128
+    """Dimension of the attention heads."""
+    d_node: int = 256
+    """Dimension of the node features.
+
+    Increasing this hyperparameter might lead to better accuracy,
+    with a relatively small increase in inference time.
+    """
+    d_feedforward: int = 256
+    """Dimension of the feedforward network in the attention layer."""
+    num_heads: int = 8
+    """Attention heads per attention layer."""
+    num_attention_layers: int = 2
+    """The number of attention layers in each layer of the graph
+    neural network. Depending on the dataset, increasing this hyperparameter might
+    lead to better accuracy, at the cost of increased training and evaluation time.
+    """
+    num_gnn_layers: int = 2
+    """The number of graph neural network layers.
+
+    In general, decreasing this hyperparameter to 1 will lead to much faster models,
+    at the expense of accuracy. Increasing it may or may not lead to better accuracy,
+    depending on the dataset, at the cost of increased training and evaluation time.
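A minimal sketch (not part of the diff) of how these model hypers could be instantiated and overridden in Python, assuming the ``init_with_defaults`` helper imported above returns a plain dict pre-filled with the defaults documented here; in practice the same values are normally set through the training options file:

    hypers = init_with_defaults(ModelHypers)  # start from the documented defaults
    hypers["cutoff"] = 5.5                    # larger cutoff: more neighbors, slower model
    hypers["num_gnn_layers"] = 1              # single GNN layer: faster, usually less accurate
    hypers["d_pet"] = 256                     # wider edge features: potentially more accurate, higher cost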
+ """ + normalization: Literal["RMSNorm", "LayerNorm"] = "RMSNorm" + """Layer normalization type.""" + activation: Literal["SiLU", "SwiGLU"] = "SwiGLU" + """Activation function.""" + attention_temperature: float = 1.0 + """The temperature scaling factor for attention scores.""" + transformer_type: Literal["PreLN", "PostLN"] = "PreLN" + """The order in which the layer normalization and attention + are applied in a transformer block. Available options are ``PreLN`` + (normalization before attention) and ``PostLN`` (normalization after attention).""" + featurizer_type: Literal["residual", "feedforward"] = "feedforward" + """Implementation of the featurizer of the model to use. Available + options are ``residual`` (the original featurizer from the PET paper, that uses + residual connections at each GNN layer for readout) and ``feedforward`` (a modern + version that uses the last representation after all GNN iterations for readout). + Additionally, the feedforward version uses bidirectional features flow during the + message passing iterations, that favors features flowing from atom ``i`` to atom + ``j`` to be not equal to the features flowing from atom ``j`` to atom ``i``.""" + zbl: bool = False + """Use ZBL potential for short-range repulsion""" + long_range: LongRangeHypers = init_with_defaults(LongRangeHypers) + """Long-range Coulomb interactions parameters.""" + + +class TrainerHypers(TypedDict): + """Hyperparameters for training PET models.""" + + distributed: bool = False + """Whether to use distributed training""" + distributed_port: int = 39591 + """Port for distributed communication among processes""" + batch_size: int = 16 + """The number of samples to use in each batch of training. This + hyperparameter controls the tradeoff between training speed and memory usage. In + general, larger batch sizes will lead to faster training, but might require more + memory.""" + num_epochs: int = 1000 + """Number of epochs.""" + warmup_fraction: float = 0.01 + """Fraction of training steps used for learning rate warmup.""" + learning_rate: float = 1e-4 + """Learning rate.""" + weight_decay: Optional[float] = None + + log_interval: int = 1 + """Interval to log metrics.""" + checkpoint_interval: int = 100 + """Interval to save checkpoints.""" + atomic_baseline: FixedCompositionWeights = {} + """The baselines for each target. + + By default, ``metatrain`` will fit a linear model (:class:`CompositionModel + `) to compute the + least squares baseline for each atomic species for each target. + + However, this hyperparameter allows you to provide your own baselines. + The value of the hyperparameter should be a dictionary where the keys are the + target names, and the values are either (1) a single baseline to be used for + all atomic types, or (2) a dictionary mapping atomic types to their baselines. + For example: + + - ``atomic_baseline: {"energy": {1: -0.5, 6: -10.0}}`` will fix the energy + baseline for hydrogen (Z=1) to -0.5 and for carbon (Z=6) to -10.0, while + fitting the baselines for the energy of all other atomic types, as well + as fitting the baselines for all other targets. + - ``atomic_baseline: {"energy": -5.0}`` will fix the energy baseline for + all atomic types to -5.0. + - ``atomic_baseline: {"mtt:dos": 0.0}`` sets the baseline for the "mtt:dos" + target to 0.0, effectively disabling the atomic baseline for that target. 
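As a quick worked example of the first bullet above: with ``atomic_baseline: {"energy": {1: -0.5, 6: -10.0}}``, the fixed baselines for a CH4 structure (one carbon, four hydrogens) add up to 1 * (-10.0) + 4 * (-0.5) = -12.0, which is removed from the target energy during training and added back on top of the prediction at evaluation time.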
+
+    This atomic baseline is subtracted from the targets during training, which
+    avoids the main model having to learn atomic contributions, and likely makes
+    training easier. When the model is used in evaluation mode, the atomic baseline
+    is added on top of the model predictions automatically.
+
+    .. note::
+
+        This atomic baseline is a per-atom contribution. Therefore, if the property
+        you are predicting is a sum over all atoms (e.g., total energy), the
+        contribution of the atomic baseline to the total property will be the
+        atomic baseline multiplied by the number of atoms of that type in the
+        structure.
+    """
+    scale_targets: bool = True
+    """Normalize targets to unit standard deviation during training."""
+    fixed_scaling_weights: FixedScalerWeights = {}
+    """Weights for target scaling.
+
+    This is passed to the ``fixed_weights`` argument of
+    :meth:`Scaler.train_model `,
+    see its documentation to understand exactly what to pass here.
+    """
+    per_structure_targets: list[str] = []
+    """Targets to calculate per-structure losses."""
+    num_workers: Optional[int] = None
+    """Number of workers for data loading. If not provided, it is set
+    automatically."""
+    log_mae: bool = True
+    """Log MAE alongside RMSE."""
+    log_separate_blocks: bool = False
+    """Log per-block error."""
+    best_model_metric: Literal["rmse_prod", "mae_prod", "loss"] = "mae_prod"
+    """Metric used to select the best checkpoint (e.g., ``rmse_prod``)."""
+    grad_clip_norm: float = 1.0
+    """Maximum gradient norm value."""
+    loss: str | dict[str, LossSpecification | str] = "mse"
+    """The loss function to be used. See
+    :ref:`loss-functions` for more details."""
+    batch_atom_bounds: list[Optional[int]] = [None, None]
+    """Bounds for the number of atoms per batch as [min, max]. Batches with atom
+    counts outside these bounds will be skipped during training. Use ``None`` for
+    either value to disable that bound. This is useful for preventing out-of-memory
+    errors and ensuring consistent computational load. Default: ``[None, None]``."""
+
+    finetune: NoFinetuneHypers | FinetuneHypers = {
+        "read_from": None,
+        "method": "full",
+        "config": {},
+        "inherit_heads": {},
+    }
+    """Parameters for fine-tuning trained PET models.
+
+    See :ref:`label_fine_tuning_concept` for more details.
+    """
diff --git a/src/metatrain/pet_13/model.py b/src/metatrain/pet_13/model.py
new file mode 100644
index 0000000000..d6b18cd905
--- /dev/null
+++ b/src/metatrain/pet_13/model.py
@@ -0,0 +1,1646 @@
+import logging
+import typing
+import warnings
+from math import prod
+from typing import Any, Dict, List, Literal, Optional, Tuple
+
+import metatensor.torch as mts
+import torch
+from metatensor.torch import Labels, TensorBlock, TensorMap
+from metatensor.torch.operations._add import _add_block_block
+from metatomic.torch import (
+    AtomisticModel,
+    ModelCapabilities,
+    ModelMetadata,
+    ModelOutput,
+    NeighborListOptions,
+    System,
+)
+from torch.utils.hooks import RemovableHandle
+
+from metatrain.utils.abc import ModelInterface
+from metatrain.utils.additive import ZBL, CompositionModel
+from metatrain.utils.data import DatasetInfo, TargetInfo
+from metatrain.utils.dtype import dtype_to_str
+from metatrain.utils.long_range import DummyLongRangeFeaturizer, LongRangeFeaturizer
+from metatrain.utils.metadata import merge_metadata
+from metatrain.utils.scaler import Scaler
+from metatrain.utils.sum_over_atoms import sum_over_atoms
+
+from .
import checkpoints +from .documentation import ModelHypers +from .modules.finetuning import apply_finetuning_strategy +from .modules.structures import get_pair_sample_labels, systems_to_batch +from .modules.transformer import CartesianTransformer + + +AVAILABLE_FEATURIZERS = typing.get_args(ModelHypers.__annotations__["featurizer_type"]) + + +class PET(ModelInterface[ModelHypers]): + """ + Metatrain-native implementation of the PET architecture. + + Originally proposed in work (https://arxiv.org/abs/2305.19302v3), + and published in the `pet` package (https://github.com/spozdn/pet). + + :param hypers: Hyperparameters for the PET model. See the documentation for details. + :param dataset_info: Information about the dataset, including atomic types and + targets. + """ + + __checkpoint_version__ = 11 + __supported_devices__ = ["cuda", "cpu"] + __supported_dtypes__ = [torch.float32, torch.float64] + __default_metadata__ = ModelMetadata( + references={"architecture": ["https://arxiv.org/abs/2305.19302v3"]} + ) + component_labels: Dict[str, List[List[Labels]]] + NUM_FEATURE_TYPES: int = 2 # node + edge features + + def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None: + super().__init__(hypers, dataset_info, self.__default_metadata__) + + # Cache frequently accessed hyperparameters + self.cutoff = float(self.hypers["cutoff"]) + self.cutoff_function = self.hypers["cutoff_function"] + self.cutoff_width = float(self.hypers["cutoff_width"]) + self.num_neighbors_adaptive = ( + float(self.hypers["num_neighbors_adaptive"]) + if self.hypers["num_neighbors_adaptive"] is not None + else None + ) + self.d_pet = self.hypers["d_pet"] + self.d_node = self.hypers["d_node"] + self.d_head = self.hypers["d_head"] + self.d_feedforward = self.hypers["d_feedforward"] + self.num_heads = self.hypers["num_heads"] + self.num_gnn_layers = self.hypers["num_gnn_layers"] + self.num_attention_layers = self.hypers["num_attention_layers"] + self.normalization = self.hypers["normalization"] + self.activation = self.hypers["activation"] + self.attention_temperature = self.hypers["attention_temperature"] + self.transformer_type = self.hypers["transformer_type"] + self.featurizer_type = self.hypers["featurizer_type"] + + self.atomic_types = dataset_info.atomic_types + self.requested_nl = NeighborListOptions( + cutoff=self.cutoff, + full_list=True, + strict=True, + ) + num_atomic_species = len(self.atomic_types) + self.gnn_layers = torch.nn.ModuleList( + [ + CartesianTransformer( + self.cutoff, + self.cutoff_width, + self.d_pet, + self.num_heads, + self.d_node, + self.d_feedforward, + self.num_attention_layers, + self.normalization, + self.activation, + self.attention_temperature, + self.transformer_type, + num_atomic_species, + layer_index == 0, # is first layer + ) + for layer_index in range(self.num_gnn_layers) + ] + ) + if self.featurizer_type not in AVAILABLE_FEATURIZERS: + raise ValueError( + f"Unknown featurizer type: {self.featurizer_type}. 
" + f"Available options are: {AVAILABLE_FEATURIZERS}" + ) + if self.featurizer_type == "feedforward": + self.num_readout_layers = 1 + self.combination_norms = torch.nn.ModuleList( + [torch.nn.LayerNorm(2 * self.d_pet) for _ in range(self.num_gnn_layers)] + ) + self.combination_mlps = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(2 * self.d_pet, 2 * self.d_pet), + torch.nn.SiLU(), + torch.nn.Linear(2 * self.d_pet, self.d_pet), + ) + for _ in range(self.num_gnn_layers) + ] + ) + else: + self.num_readout_layers = self.num_gnn_layers + self.combination_norms = torch.nn.ModuleList() + self.combination_mlps = torch.nn.ModuleList() + + self.node_embedders = torch.nn.ModuleList( + [ + torch.nn.Embedding(num_atomic_species, self.d_node) + for _ in range(self.num_readout_layers) + ] + ) + self.edge_embedder = torch.nn.Embedding(num_atomic_species, self.d_pet) + + self.node_heads = torch.nn.ModuleDict() + self.edge_heads = torch.nn.ModuleDict() + self.node_last_layers = torch.nn.ModuleDict() + self.edge_last_layers = torch.nn.ModuleDict() + self.last_layer_feature_size = ( + self.num_readout_layers * self.d_head * self.NUM_FEATURE_TYPES + ) # for LLPR + + # the model is always capable of outputting the internal features + self.outputs = { + "features": ModelOutput(per_atom=True, description="internal features") + } + + self.output_shapes: Dict[str, Dict[str, List[int]]] = {} + self.key_labels: Dict[str, Labels] = {} + self.property_labels: Dict[str, List[Labels]] = {} + self.component_labels: Dict[str, List[List[Labels]]] = {} + self.target_names: List[str] = [] + self.last_layer_parameter_names: Dict[str, List[str]] = {} # for LLPR + for target_name, target_info in dataset_info.targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target_info) + + self.register_buffer( + "species_to_species_index", + torch.full((max(self.atomic_types) + 1,), -1), + ) + for i, species in enumerate(self.atomic_types): + self.species_to_species_index[species] = i + + # long-range module + if self.hypers["long_range"]["enable"]: + self.long_range = True + if not self.hypers["long_range"]["use_ewald"]: + warnings.warn( + "Training PET with the LongRangeFeaturizer initialized " + "with `use_ewald=False` causes instabilities during training. " + "The `use_ewald` variable will be force-switched to `True`. 
" + "during training.", + UserWarning, + stacklevel=2, + ) + self.long_range_featurizer = LongRangeFeaturizer( + hypers=self.hypers["long_range"], + feature_dim=self.d_node, + neighbor_list_options=self.requested_nl, + ) + else: + self.long_range = False + self.long_range_featurizer = DummyLongRangeFeaturizer() # for torchscript + + # additive models: these are handled by the trainer at training + # time, and they are added to the output at evaluation time + composition_model = CompositionModel( + hypers={}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + additive_models = [composition_model] + + # Adds the ZBL repulsion model if requested + if self.hypers["zbl"]: + additive_models.append( + ZBL( + {}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if ZBL.is_valid_target(target_name, target_info) + }, + ), + ) + ) + self.additive_models = torch.nn.ModuleList(additive_models) + + # scaler: this is also handled by the trainer at training time + self.scaler = Scaler(hypers={}, dataset_info=dataset_info) + + self.single_label = Labels.single() + + self.finetune_config: Dict[str, Any] = {} + + def supported_outputs(self) -> Dict[str, ModelOutput]: + return self.outputs + + def restart(self, dataset_info: DatasetInfo) -> "PET": + # merge old and new dataset info + merged_info = self.dataset_info.union(dataset_info) + new_atomic_types = [ + at for at in merged_info.atomic_types if at not in self.atomic_types + ] + new_targets = { + key: value + for key, value in merged_info.targets.items() + if key not in self.dataset_info.targets + } + self.has_new_targets = len(new_targets) > 0 + + if len(new_atomic_types) > 0: + raise ValueError( + f"New atomic types found in the dataset: {new_atomic_types}. " + "The PET model does not support adding new atomic types." + ) + + # register new outputs as new last layers + for target_name, target in new_targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target) + + self.dataset_info = merged_info + + # restart the composition and scaler models + self.additive_models[0] = self.additive_models[0].restart( + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + self.scaler = self.scaler.restart(dataset_info) + + return self + + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return [self.requested_nl] + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + """ + Forward pass of the PET model. + + The forward pass processes atomic systems through multiple stages to produce + predictions for the requested outputs. The computation follows a graph neural + network architecture with attention-based message passing. 
+ + **Stage 0: Input Preparation** + + The input systems are first converted into a batched representation containing: + + - `element_indices_nodes` [n_atoms]: Atomic species of the central atoms + - `element_indices_neighbors` [n_atoms, max_num_neighbors]: Atomic species of + neighboring atoms + - `edge_vectors` [n_atoms, max_num_neighbors, 3]: Cartesian edge vectors + between central atoms and their neighbors + - `padding_mask` [n_atoms, max_num_neighbors]: Mask indicating real vs padded + neighbors + - `reverse_neighbor_index` [n_atoms * max_num_neighbors]: Index of the ji edge + for each ij edge, once the edges are flattened into an array whose first + dimension is n_atoms * max_num_neighbors + - `system_indices` [n_atoms]: System index for each central atom + - `sample_labels` [n_atoms, 2]: Metatensor Labels containing indices of each + atom in each system + + **Stage 1: Feature Computation via GNN Layers** + + Node and edge representations are computed by iterating through the GNN layers + following one of two featurization strategies: + + - **Feedforward featurization**: Features are propagated through all + `num_gnn_layers` GNN layers sequentially, using only the final layer outputs + for readout. At each layer, forward and reversed edge messages are combined + using combination MLPs to enable bidirectional information flow. + + - **Residual featurization**: Intermediate node and edge features from each + GNN layer are saved and used during readout. Edge messages between layers are + averaged to maintain information from all hops. + + During this stage, the model: + + - Embeds atomic species into learned node and edge representations + - Applies Cartesian transformer layers to update features via attention + - Uses reversed neighbor lists to enable bidirectional message passing, where + the new input message from atom `j` to atom `i` in GNN layer N+1 is the + reversed message from atom `i` to atom `j` in GNN layer N + - Applies cutoff functions to weight interactions by distance + + If the long-range module is enabled, electrostatic features computed via Ewald + summation (during training) or Particle-Particle Particle Mesh Ewald (P3M) + (during evaluation) are added to the node features from each GNN layer. + + **Stage 2: Intermediate Feature Output (Optional)** + + If "features" is requested in the outputs, node and edge features from all + layers are concatenated to produce intermediate representations. Edge features + are summed over neighbors with cutoff weighting to obtain per-node + contributions. This output can be used for transfer learning or analysis. + + **Stage 3: Last Layer Feature Computation** + + For each requested output, output-specific heads (shallow MLPs with two linear + layers and SiLU activations) are applied to both node and edge features from + each GNN layer. This produces last layer features that are specialized for each + prediction target. These features can be optionally returned as auxiliary + outputs (e.g., "mtt::aux::energy_last_layer_features") for analysis or + transfer learning. 
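The reversed-edge bookkeeping described in Stage 1 can be sketched in isolation (an editorial illustration with made-up shapes; `reverse_neighbor_index` is assumed to map each flattened i->j edge to its j->i counterpart, as described above):

    import torch

    n_atoms, max_num_neighbors, d = 4, 3, 8
    edge_messages = torch.randn(n_atoms, max_num_neighbors, d)
    # hypothetical reverse index over the flattened edge list
    reverse_neighbor_index = torch.randperm(n_atoms * max_num_neighbors)

    # flatten to (n_edges, d), gather the reversed edges, restore the NEF shape
    reversed_messages = edge_messages.reshape(-1, d)[reverse_neighbor_index].reshape(
        n_atoms, max_num_neighbors, d
    )

This is the same flatten-index-reshape pattern used in `_feedforward_featurization_impl` and `_residual_featurization_impl` further down.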
+ + **Stage 4: Atomic Predictions** + + Final linear layers are applied to the last layer features to produce per-atom + predictions for each requested output: + + - Node and edge last layer features are processed through separate linear + layers for each output block + - Contributions from all GNN layers are summed + - Edge contributions are summed over neighbors with cutoff weighting + - For rank-2 Cartesian tensors (e.g., stress), predictions are symmetrized and + normalized by cell volume + - Multiple tensor blocks per output are handled independently + + **Post-processing (Evaluation Only)** + + During evaluation (not training), the following transformations are applied: + + 1. **Scaling**: Predictions are scaled using learned or configured scale + factors + 2. **Additive contributions**: Composition model and optional ZBL repulsion + contributions are added to the predictions + + :param systems: List of `metatomic.torch.System` objects to process. Each + system should contain atomic positions, species, and cell information, with + neighbor lists computed according to `requested_neighbor_lists()`. + :param outputs: Dictionary of requested outputs in the format + {output_name: ModelOutput(...)}. The model supports: + + - Target properties (energy, forces, stress, etc.) + - "features": intermediate representations from Stage 2 + - Auxiliary last layer features (e.g., + "mtt::aux::energy_last_layer_features") + + :param selected_atoms: Optional `metatensor.torch.Labels` object specifying a + subset of atoms for which to compute outputs. If `None`, all atoms are + included. This is useful for computing properties for specific atomic + environments. + :return: Dictionary of `metatensor.torch.TensorMap` objects containing the + requested outputs. Each TensorMap contains per-atom or per-structure + predictions (depending on the ModelOutput configuration) with appropriate + metatensor metadata (samples, components, properties). 
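As a usage illustration (an editorial sketch, not part of the patch: the target names are hypothetical, the `System` is assumed to already carry the neighbor list returned by `requested_neighbor_lists()`, and `ModelOutput` is the same `metatomic.torch` class used elsewhere in this file):

    outputs = model(
        [system],
        {
            "energy": ModelOutput(quantity="energy", unit="eV", per_atom=False),
            "mtt::aux::energy_last_layer_features": ModelOutput(per_atom=True),
        },
    )
    energy = outputs["energy"].block().values  # one row per system
    ll_feats = outputs["mtt::aux::energy_last_layer_features"].block().values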
+ """ + device = systems[0].device + return_dict: Dict[str, TensorMap] = {} + nl_options = self.requested_neighbor_lists()[0] + + if self.single_label.values.device != device: + self._move_labels_to_device(device) + + # **Stage 0: Input Preparation** + ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + padding_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + sample_labels, + centers, + nef_to_edges_neighbor, + ) = systems_to_batch( + systems, + nl_options, + self.atomic_types, + self.species_to_species_index, + self.cutoff_function, + self.cutoff_width, + self.num_neighbors_adaptive, + selected_atoms, + ) + + pair_sample_labels = get_pair_sample_labels( + systems, sample_labels, nl_options, device + ) + + # Optional diagnostic token capture: register temporary module hooks + diagnostic_handles = torch.jit.annotate(List[Any], []) + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + diagnostic_handles = self._prepare_diagnostic_handles( + outputs, + return_dict, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + + # the scaled_dot_product_attention function from torch cannot do + # double backward, so we will use manual attention if needed + use_manual_attention = edge_vectors.requires_grad and self.training + + with torch.profiler.record_function("PET::_calculate_features"): + # **Stage 1: Feature Computation via GNN Layers** + featurizer_inputs: Dict[str, torch.Tensor] = dict( + element_indices_nodes=element_indices_nodes, + element_indices_neighbors=element_indices_neighbors, + edge_vectors=edge_vectors, + edge_distances=edge_distances, + reverse_neighbor_index=reverse_neighbor_index, + padding_mask=padding_mask, + cutoff_factors=cutoff_factors, + ) + for featurizer_input_name, tensor in featurizer_inputs.items(): + if "mtt::features::" + featurizer_input_name in outputs: + return_dict["mtt::features::" + featurizer_input_name] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + node_features_list, edge_features_list = self._calculate_features( + featurizer_inputs, + use_manual_attention=use_manual_attention, + ) + + # If the long-range module is activated, we add the long-range features + # on top of the node features + + if self.long_range: + long_range_features = self._calculate_long_range_features( + systems, node_features_list, edge_distances, padding_mask + ) + for i in range(self.num_readout_layers): + node_features_list[i] = ( + node_features_list[i] + long_range_features + ) * 0.5**0.5 + + # **Stage 2: Intermediate Feature Output (Optional)** + with torch.profiler.record_function("PET::_get_output_features"): + if "features" in outputs: + features_dict = self._get_output_features( + node_features_list, + edge_features_list, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + # Since return_dict.update(features_dict) is not Torch-Scriptable, + # we use a simple iteration over the features_dict items. 
+ for k, v in features_dict.items(): + return_dict[k] = v + + # **Stage 3: Last Layer Feature Computation** + with torch.profiler.record_function("PET::_calculate_last_layer_features"): + node_last_layer_features_dict, edge_last_layer_features_dict = ( + self._calculate_last_layer_features( + node_features_list, + edge_features_list, + ) + ) + last_layer_features_dict = self._get_output_last_layer_features( + node_last_layer_features_dict, + edge_last_layer_features_dict, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + + for k, v in last_layer_features_dict.items(): + return_dict[k] = v + + # **Stage 4: Atomic Predictions** + with torch.profiler.record_function("PET::_calculate_atomic_predictions"): + node_atomic_predictions_dict, edge_atomic_predictions_dict = ( + self._calculate_atomic_predictions( + node_last_layer_features_dict, + edge_last_layer_features_dict, + padding_mask, + cutoff_factors, + outputs, + ) + ) + atomic_predictions_dict = self._get_output_atomic_predictions( + systems, + node_atomic_predictions_dict, + edge_atomic_predictions_dict, + edge_vectors, + system_indices, + sample_labels, + outputs, + selected_atoms, + ) + + for k, v in atomic_predictions_dict.items(): + return_dict[k] = v + + # **Post-processing (Evaluation Only)** + + with torch.profiler.record_function("PET::post-processing"): + if not self.training: + # at evaluation, we also introduce the scaler and additive contributions + return_dict = self.scaler( + systems, return_dict, selected_atoms=selected_atoms + ) + for additive_model in self.additive_models: + outputs_for_additive_model: Dict[str, ModelOutput] = {} + for name, output in outputs.items(): + if name in additive_model.outputs: + outputs_for_additive_model[name] = output + additive_contributions = additive_model( + systems, + outputs_for_additive_model, + selected_atoms, + ) + for name in additive_contributions: + # TODO: uncomment this after metatensor.torch.add + # is updated to handle sparse sums + # return_dict[name] = metatensor.torch.add( + # return_dict[name], + # additive_contributions[name].to( + # device=return_dict[name].device, + # dtype=return_dict[name].dtype + # ), + # ) + # TODO: "manual" sparse sum: update to metatensor.torch.add + # after sparse sum is implemented in metatensor.operations + output_blocks: List[TensorBlock] = [] + for k, b in return_dict[name].items(): + if k in additive_contributions[name].keys: + output_blocks.append( + _add_block_block( + b, + additive_contributions[name] + .block(k) + .to(device=b.device, dtype=b.dtype), + ) + ) + else: + output_blocks.append(b) + return_dict[name] = TensorMap( + return_dict[name].keys, output_blocks + ) + + # remove any diagnostic hooks we registered and attach tokens + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + for h in diagnostic_handles: + h.remove() + + return return_dict + + def _create_diagnostic_feature_tensormap( + self, + tensor: torch.Tensor, + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> TensorMap: + assert tensor.shape[0] == sample_labels.values.shape[0], ( + "diagnostic feature tensor must be per-atom or per-pair like in shape." 
+ f" Got tensor.shape = {tensor.shape}, " + ) + + device = tensor.device + + outp = tensor.detach().clone() + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + if outp.shape[1] == 1: # node-like, shape (n_atoms, 1, d) + outp = outp.squeeze(1) + labels = sample_labels + + else: # edge-like, shape (n_atoms, num_neighbors, d) + outp = outp[centers, nef_to_edges_neighbor] + labels = pair_sample_labels + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + return TensorMap( + Labels(["_"], torch.tensor([[0]])).to(device=device), + [ + TensorBlock( + values=outp, + samples=labels.to(device=device), + components=[], + properties=Labels( + ["_"], + torch.arange(outp.shape[1]).reshape(-1, 1), + ).to(device=device), + ) + ], + ) + + @torch.jit.ignore + def _prepare_diagnostic_handles( + self, + outputs: Dict[str, ModelOutput], + return_dict: Dict[str, Any], + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> List[Any]: + """ + Prepare forward hooks to capture diagnostic tokens from internal modules. + + :param outputs: Dictionary of requested outputs. + :param return_dict: Dictionary to store captured tokens. + :param centers: Tensor mapping center atoms to their indices. + :param nef_to_edges_neighbor: Tensor mapping neighbor edges to their indices. + :param sample_labels: Labels for individual atoms. + :param pair_sample_labels: Labels for atom pairs. + + :return: List of removable handles for the registered hooks. + """ + diagnostic_handles = [] + + def _resolve_module(path: str) -> Any: + obj: Any = self + for part in path.split("."): + if part.isdigit(): + obj = obj[int(part)] + else: + if not hasattr(obj, part): + raise AttributeError( + f"Module path '{path}' not found at '{part}'" + ) + obj = getattr(obj, part) + return obj + + # Build list of possible module paths that can be captured for these model + # hypers. 
+ possible_capture_paths: List[str] = [] + + for i in range(self.num_readout_layers): + possible_capture_paths.append(f"node_embedders.{i}") + possible_capture_paths.append("edge_embedder") + + for i in range(self.num_gnn_layers): + # Total GNN layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}_node") + possible_capture_paths.append(f"gnn_layers.{i}_edge") + + # Finer-grained: embeddings and compressions + possible_capture_paths.append(f"gnn_layers.{i}.edge_embedder") + if i > 0: + possible_capture_paths.append(f"gnn_layers.{i}.neighbor_embedder") + possible_capture_paths.append(f"gnn_layers.{i}.compress") + + # transformer layers + for j in range(self.num_attention_layers): + # Transformer layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_node") + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_edge") + + # Finer-grained: attention and MLP submodules + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_contraction" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_mlp" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_mlp" + ) + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}.mlp") + + # TODO: this is subtle, depending on whether the featurizer is feedforward or + # residual. To fix later. + # possible_capture_paths.append(f"combination_norms") + # possible_capture_paths.append(f"combination_mlps") + + for path in possible_capture_paths: + if "mtt::features::" + path not in outputs: + continue + + if "_node" in path: + suffix = "_node" + path = path.replace(suffix, "") + elif "_edge" in path: + suffix = "_edge" + path = path.replace(suffix, "") + else: + suffix = "" + + module = _resolve_module(path) + + def make_hook(p: str, suffix: str) -> Any: + def _hook( + module: torch.nn.Module, inp: Any, outp: Any + ) -> None: + if isinstance(outp, tuple): + assert "_node" in suffix or "_edge" in suffix, ( + "When capturing from a module that returns multiple " + "outputs, the requested output must carry the suffix " + "'_node' or '_edge'." + ) + if suffix == "_node": + tensor = outp[0] + else: + assert suffix == "_edge" + tensor = outp[1] + + return_dict[f"mtt::features::{p}{suffix}"] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + else: + return_dict["mtt::features::" + p] = ( + self._create_diagnostic_feature_tensormap( + outp, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + + return _hook + + handle = module.register_forward_hook(make_hook(path, suffix)) + diagnostic_handles.append(handle) + + return diagnostic_handles + + def _calculate_features( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Calculate node and edge features using the selected featurization strategy. + Returns lists of feature tensors from GNN layers. 
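The capture mechanism above is built on standard PyTorch forward hooks; a minimal standalone sketch of the same idea (a toy module, unrelated to the PET internals):

    import torch

    captured = {}
    toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.SiLU())

    def hook(module, inputs, output):
        # stash a detached copy of the intermediate output, keyed by a path-like name
        captured["layers.0"] = output.detach().clone()

    handle = toy[0].register_forward_hook(hook)
    toy(torch.randn(2, 8))
    handle.remove()  # hooks are temporary, as in _prepare_diagnostic_handles

In the model, requesting e.g. "mtt::features::gnn_layers.0.trans.layers.1.attention" as an output attaches such a hook to the corresponding submodule and returns the captured tensor wrapped in a `TensorMap`.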
+ + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors + - List of edge feature tensors + In the case of feedforward featurization, each list contains a single tensor + from the final GNN layer. In the case of residual featurization, each list + contains tensors from all GNN layers. + """ + if self.featurizer_type == "feedforward": + return self._feedforward_featurization_impl(inputs, use_manual_attention) + else: + return self._residual_featurization_impl(inputs, use_manual_attention) + + def _feedforward_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Feedforward featurization: iterates features through all GNN layers, + returning only the final layer outputs. Uses combination MLPs to mix + forward and reversed edge messages at each layer. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from the final GNN layer + - List of edge feature tensors from the final GNN layer + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + + input_node_embeddings = self.node_embedders[0](inputs["element_indices_nodes"]) + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for combination_norm, combination_mlp, gnn_layer in zip( + self.combination_norms, self.combination_mlps, self.gnn_layers, strict=True + ): + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + input_node_embeddings = output_node_embeddings + new_input_edge_embeddings = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + # input_messages = 0.5 * (output_edge_embeddings + new_input_messages) + concatenated = torch.cat( + [output_edge_embeddings, new_input_edge_embeddings], dim=-1 + ) + input_edge_embeddings = ( + input_edge_embeddings + + output_edge_embeddings + + combination_mlp(combination_norm(concatenated)) + ) + + node_features_list.append(input_node_embeddings) + edge_features_list.append(input_edge_embeddings) + return node_features_list, edge_features_list + + def _residual_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Residual featurization: saves intermediate features from each GNN layer + for use in readout. 
Averages forward and reversed edge messages between layers. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from all GNN layers + - List of edge feature tensors from all GNN layers + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for node_embedder, gnn_layer in zip( + self.node_embedders, self.gnn_layers, strict=True + ): + input_node_embeddings = node_embedder(inputs["element_indices_nodes"]) + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + node_features_list.append(output_node_embeddings) + edge_features_list.append(output_edge_embeddings) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + # (Flatten, index, and reshape to the original shape) + new_input_messages = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + input_edge_embeddings = 0.5 * (input_edge_embeddings + new_input_messages) + return node_features_list, edge_features_list + + def _calculate_long_range_features( + self, + systems: List[System], + node_features_list: List[torch.Tensor], + edge_distances: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + """ + Calculate long-range electrostatic features using Ewald summation. + Forces use_ewald=True during training for stability. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_distances: Tensor of edge distances [n_atoms, max_num_neighbors]. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :return: Tensor of long-range features [n_atoms, d_pet]. + """ + if self.training: + # Currently, the long-range implementation show instabilities + # during training if P3MCalculator is used instead of the + # EwaldCalculator. We will use the EwaldCalculator for training. 
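A short aside on the scalings used just below and in `forward` (`* (1 / len(node_features_list)) ** 0.5` and `* 0.5**0.5`): summing N roughly independent, unit-variance feature tensors and rescaling by 1/sqrt(N) keeps the result on the same scale as the individual inputs. A quick numerical check with toy tensors:

    import torch

    layers = [torch.randn(10000, 64) for _ in range(4)]
    combined = torch.stack(layers).sum(dim=0) * (1 / len(layers)) ** 0.5
    print(combined.std())  # ~1.0, same scale as each input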
+ self.long_range_featurizer.use_ewald = True + flattened_lengths = edge_distances[padding_mask] + short_range_features = ( + torch.stack(node_features_list).sum(dim=0) + * (1 / len(node_features_list)) ** 0.5 + ) + long_range_features = self.long_range_featurizer( + systems, short_range_features, flattened_lengths + ) + return long_range_features + + def _get_output_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Concatenate node and edge features from all layers into intermediate + feature representations. Edge features are summed with cutoff weighting. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping "features" to a TensorMap of intermediate + representations, either per-atom or summed over atoms. + """ + features_dict: Dict[str, TensorMap] = {} + node_features = torch.cat(node_features_list, dim=1) + edge_features = torch.cat(edge_features_list, dim=2) + edge_features = (edge_features * cutoff_factors[:, :, None]).sum(dim=1) + features = torch.cat([node_features, edge_features], dim=1) + + feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=features, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + features.shape[-1], device=features.device + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + feature_tmap = mts.slice( + feature_tmap, + axis="samples", + selection=selected_atoms, + ) + if requested_outputs["features"].per_atom: + features_dict["features"] = feature_tmap + else: + features_dict["features"] = sum_over_atoms(feature_tmap) + return features_dict + + def _calculate_last_layer_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + ) -> Tuple[Dict[str, List[torch.Tensor]], Dict[str, List[torch.Tensor]]]: + """ + Apply output-specific heads to node and edge features from each GNN layer. + Returns dictionaries mapping output names to lists of head-transformed features. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of node last layer features + - Dictionary mapping output names to lists of edge last layer features + """ + node_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + + # Calculating node last layer features + for output_name, node_heads in self.node_heads.items(): + if output_name not in node_last_layer_features_dict: + node_last_layer_features_dict[output_name] = [] + for i, node_head in enumerate(node_heads): + node_last_layer_features_dict[output_name].append( + node_head(node_features_list[i]) + ) + + # Calculating edge last layer features + for output_name, edge_heads in self.edge_heads.items(): + if output_name not in edge_last_layer_features_dict: + edge_last_layer_features_dict[output_name] = [] + for i, edge_head in enumerate(edge_heads): + edge_last_layer_features_dict[output_name].append( + edge_head(edge_features_list[i]) + ) + + return node_last_layer_features_dict, edge_last_layer_features_dict + + def _get_output_last_layer_features( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge last layer features for requested last layer + features output. Edge features are summed with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping requested last layer features output names + to TensorMaps of last layer features, either per-atom or summed over atoms. 
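To make the sizes concrete: for every target, the last-layer features concatenate the node and edge head outputs of each readout layer, which is where the `last_layer_feature_size = num_readout_layers * d_head * NUM_FEATURE_TYPES` value registered in the constructor comes from (presumably two feature types, node and edge). The auxiliary output names follow `get_last_layer_features_name`, defined later in this file; for hypothetical target names:

    # "energy"      -> "mtt::aux::energy_last_layer_features"
    # "mtt::dipole" -> "mtt::aux::dipole_last_layer_features"
    assert get_last_layer_features_name("mtt::dipole") == "mtt::aux::dipole_last_layer_features"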
+ """ + last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + last_layer_features_outputs: Dict[str, TensorMap] = {} + for output_name in node_last_layer_features_dict.keys(): + if not should_compute_last_layer_features(output_name, requested_outputs): + continue + if output_name not in last_layer_features_dict: + last_layer_features_dict[output_name] = [] + for i in range(len(node_last_layer_features_dict[output_name])): + node_last_layer_features = node_last_layer_features_dict[output_name][i] + edge_last_layer_features = edge_last_layer_features_dict[output_name][i] + edge_last_layer_features = ( + edge_last_layer_features * cutoff_factors[:, :, None] + ).sum(dim=1) + last_layer_features_dict[output_name].append(node_last_layer_features) + last_layer_features_dict[output_name].append(edge_last_layer_features) + + for output_name in requested_outputs: + if not ( + output_name.startswith("mtt::aux::") + and output_name.endswith("_last_layer_features") + ): + continue + base_name = output_name.replace("mtt::aux::", "").replace( + "_last_layer_features", "" + ) + # the corresponding output could be base_name or mtt::base_name + if f"mtt::{base_name}" in last_layer_features_dict: + base_name = f"mtt::{base_name}" + last_layer_features_values = torch.cat( + last_layer_features_dict[base_name], dim=1 + ) + last_layer_feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=last_layer_features_values, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + last_layer_features_values.shape[-1], + device=last_layer_features_values.device, + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + last_layer_feature_tmap = mts.slice( + last_layer_feature_tmap, + axis="samples", + selection=selected_atoms, + ) + last_layer_features_options = requested_outputs[output_name] + if last_layer_features_options.per_atom: + last_layer_features_outputs[output_name] = last_layer_feature_tmap + else: + last_layer_features_outputs[output_name] = sum_over_atoms( + last_layer_feature_tmap + ) + return last_layer_features_outputs + + def _calculate_atomic_predictions( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + padding_mask: torch.Tensor, + cutoff_factors: torch.Tensor, + outputs: Dict[str, ModelOutput], + ) -> Tuple[ + Dict[str, List[List[torch.Tensor]]], Dict[str, List[List[torch.Tensor]]] + ]: + """ + Apply final linear layers to last layer features to produce + per-atom predictions. Handles multiple blocks per output and sums + edge contributions with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param outputs: Dictionary of requested outputs. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of lists of node atomic + prediction tensors (one list per GNN layer, one tensor per block) + - Dictionary mapping output names to lists of lists of edge atomic + prediction tensors (one list per GNN layer, one tensor per block) + """ + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + + # Computing node atomic predictions. Since we have last layer features + # for each GNN layer, and each last layer can have multiple blocks, + # we apply each last layer block to each of the last layer features. + + for output_name, node_last_layers in self.node_last_layers.items(): + if output_name in outputs: + node_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, node_last_layer in enumerate(node_last_layers): + node_last_layer_features = node_last_layer_features_dict[ + output_name + ][i] + node_atomic_predictions_by_block: List[torch.Tensor] = [] + for node_last_layer_by_block in node_last_layer.values(): + node_atomic_predictions_by_block.append( + node_last_layer_by_block(node_last_layer_features) + ) + node_atomic_predictions_dict[output_name].append( + node_atomic_predictions_by_block + ) + + # Computing edge atomic predictions. Following the same logic as above, + # we (1) iterate over the last layer features and last layer blocks, and (2) + # sum the edge features with cutoff factors to get their per-node contribution. + + for output_name, edge_last_layers in self.edge_last_layers.items(): + if output_name in outputs: + edge_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, edge_last_layer in enumerate(edge_last_layers): + edge_last_layer_features = edge_last_layer_features_dict[ + output_name + ][i] + edge_atomic_predictions_by_block: List[torch.Tensor] = [] + for edge_last_layer_by_block in edge_last_layer.values(): + edge_atomic_predictions = edge_last_layer_by_block( + edge_last_layer_features + ) + expanded_padding_mask = padding_mask[..., None].repeat( + 1, 1, edge_atomic_predictions.shape[2] + ) + edge_atomic_predictions = torch.where( + ~expanded_padding_mask, 0.0, edge_atomic_predictions + ) + edge_atomic_predictions_by_block.append( + (edge_atomic_predictions * cutoff_factors[:, :, None]).sum( + dim=1 + ) + ) + edge_atomic_predictions_dict[output_name].append( + edge_atomic_predictions_by_block + ) + + return node_atomic_predictions_dict, edge_atomic_predictions_dict + + def _get_output_atomic_predictions( + self, + systems: List[System], + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_vectors: torch.Tensor, + system_indices: torch.Tensor, + sample_labels: Labels, + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge atomic predictions into final TensorMaps. + Handles rank-2 Cartesian tensors by symmetrizing them. + Returns per-atom or per-structure predictions based on output configuration. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of node atomic prediction tensors (one list per GNN layer, + one tensor per block). 
+ :param edge_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of edge atomic prediction tensors (one list per GNN layer, + one tensor per block). + :param edge_vectors: Tensor of edge vectors [n_atoms, max_num_neighbors, 3]. + :param system_indices: Tensor mapping each atom to its system index + [n_atoms]. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param outputs: Dictionary of requested outputs. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :return: Dictionary mapping requested output names to TensorMaps of + predictions, either per-atom or summed over atoms. + """ + atomic_predictions_tmap_dict: Dict[str, TensorMap] = {} + for output_name in self.target_names: + if output_name in outputs: + atomic_predictions_by_block = { + key: torch.zeros( + 1, dtype=edge_vectors.dtype, device=edge_vectors.device + ) + for key in self.output_shapes[output_name].keys() + } + + node_atomic_predictions_by_block = node_atomic_predictions_dict[ + output_name + ] + edge_atomic_predictions_by_block = edge_atomic_predictions_dict[ + output_name + ] + for i in range(len(node_atomic_predictions_by_block)): + node_atomic_prediction_block = node_atomic_predictions_by_block[i] + edge_atomic_prediction_block = edge_atomic_predictions_by_block[i] + for j, key in enumerate(atomic_predictions_by_block): + node_atomic_predictions = node_atomic_prediction_block[j] + edge_atomic_predictions = edge_atomic_prediction_block[j] + atomic_predictions_by_block[key] = atomic_predictions_by_block[ + key + ] + (node_atomic_predictions + edge_atomic_predictions) + + if output_name == "non_conservative_stress": # TODO: variants + block_key = list(atomic_predictions_by_block.keys())[0] + output_shapes_values = list( + self.output_shapes[output_name].values() + ) + num_properties = output_shapes_values[0][-1] + symmetrized = process_non_conservative_stress( + atomic_predictions_by_block[block_key], + systems, + system_indices, + num_properties, + ) + atomic_predictions_by_block[block_key] = symmetrized + + blocks = [ + TensorBlock( + values=atomic_predictions_by_block[key].reshape([-1] + shape), + samples=sample_labels, + components=components, + properties=properties, + ) + for key, shape, components, properties in zip( + self.output_shapes[output_name].keys(), + self.output_shapes[output_name].values(), + self.component_labels[output_name], + self.property_labels[output_name], + strict=True, + ) + ] + atomic_predictions_tmap_dict[output_name] = TensorMap( + keys=self.key_labels[output_name], + blocks=blocks, + ) + # If selected atoms request is provided, we slice the atomic predictions + # tensor maps to get the predictions for the selected atoms only. + + if selected_atoms is not None: + for output_name, tmap in atomic_predictions_tmap_dict.items(): + atomic_predictions_tmap_dict[output_name] = mts.slice( + tmap, axis="samples", selection=selected_atoms + ) + + # If per-atom predictions are requested, we return the atomic predictions + # tensor maps. Otherwise, we sum the atomic predictions over the atoms + # to get the final per-structure predictions for each requested output. 
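For reference, the `selected_atoms` selection used in the slicing step above is an ordinary `Labels` object over the per-atom samples (an editorial sketch; the sample names are assumed to follow the usual ("system", "atom") convention of the per-atom sample labels):

    from metatensor.torch import Labels
    import torch

    # keep only atoms 0 and 3 of the first system
    selected_atoms = Labels(["system", "atom"], torch.tensor([[0, 0], [0, 3]]))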
+ + for output_name, atomic_property in atomic_predictions_tmap_dict.items(): + if outputs[output_name].per_atom: + atomic_predictions_tmap_dict[output_name] = atomic_property + else: + atomic_predictions_tmap_dict[output_name] = sum_over_atoms( + atomic_property + ) + + return atomic_predictions_tmap_dict + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + context: Literal["restart", "finetune", "export"], + ) -> "PET": + if context == "restart": + logging.info(f"Using latest model from epoch {checkpoint['epoch']}") + model_state_dict = checkpoint["model_state_dict"] + elif context in {"finetune", "export"}: + logging.info(f"Using best model from epoch {checkpoint['best_epoch']}") + model_state_dict = checkpoint["best_model_state_dict"] + else: + raise ValueError("Unknown context tag for checkpoint loading!") + + # Create the model + model_data = checkpoint["model_data"] + model = cls( + hypers=model_data["model_hypers"], + dataset_info=model_data["dataset_info"], + ) + + finetune_config = model_state_dict.pop("finetune_config", {}) + if finetune_config: + # Apply the finetuning strategy + model = apply_finetuning_strategy(model, finetune_config) + state_dict_iter = iter(model_state_dict.values()) + next(state_dict_iter) # skip the species_to_species_index + dtype = next(state_dict_iter).dtype + model.to(dtype).load_state_dict(model_state_dict) + model.additive_models[0].sync_tensor_maps() + model.scaler.sync_tensor_maps() + + # Loading the metadata from the checkpoint + model.metadata = merge_metadata(model.metadata, checkpoint.get("metadata")) + + return model + + def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: + dtype = next(self.parameters()).dtype + if dtype not in self.__supported_dtypes__: + raise ValueError(f"unsupported dtype {dtype} for PET") + + # Make sure the model is all in the same dtype + # For example, after training, the additive models could still be in + # float64 + self.to(dtype) + + # Additionally, the composition model contains some `TensorMap`s that cannot + # be registered correctly with Pytorch. This function moves them: + self.additive_models[0].weights_to(torch.device("cpu"), torch.float64) + + interaction_ranges = [self.num_gnn_layers * self.cutoff] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + + capabilities = ModelCapabilities( + outputs=self.outputs, + atomic_types=self.atomic_types, + interaction_range=interaction_range, + length_unit=self.dataset_info.length_unit, + supported_devices=self.__supported_devices__, + dtype=dtype_to_str(dtype), + ) + + metadata = merge_metadata(self.metadata, metadata) + + return AtomisticModel(self.eval(), metadata, capabilities) + + def _add_output(self, target_name: str, target_info: TargetInfo) -> None: + """ + Register a new output target by creating corresponding heads and last layers. + Sets up node/edge heads and linear layers for all readout layers. + + :param target_name: Name of the target to add. + :param target_info: TargetInfo object containing details about the target. + """ + # one output shape for each tensor block, grouped by target (i.e. 
tensormap) + self.output_shapes[target_name] = {} + for key, block in target_info.layout.items(): + dict_key = target_name + for n, k in zip(key.names, key.values, strict=True): + dict_key += f"_{n}_{int(k)}" + self.output_shapes[target_name][dict_key] = [ + len(comp.values) for comp in block.components + ] + [len(block.properties.values)] + + self.outputs[target_name] = ModelOutput( + quantity=target_info.quantity, + unit=target_info.unit, + per_atom=True, + description=target_info.description, + ) + + self.node_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_node, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_pet, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.node_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: torch.nn.Linear( + self.d_head, + prod(shape), + bias=True, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: torch.nn.Linear( + self.d_head, + prod(shape), + bias=True, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + # Register last-layer parameters, in the same order as they are returned as + # last-layer features in the model + self.last_layer_parameter_names[target_name] = [] + for layer_index in range(self.num_readout_layers): + for key in self.output_shapes[target_name].keys(): + self.last_layer_parameter_names[target_name].append( + f"node_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + self.last_layer_parameter_names[target_name].append( + f"edge_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + + ll_features_name = get_last_layer_features_name(target_name) + self.outputs[ll_features_name] = ModelOutput( + per_atom=True, description=f"last layer features for {target_name}" + ) + self.key_labels[target_name] = target_info.layout.keys + self.component_labels[target_name] = [ + block.components for block in target_info.layout.blocks() + ] + self.property_labels[target_name] = [ + block.properties for block in target_info.layout.blocks() + ] + + def _move_labels_to_device(self, device: torch.device) -> None: + self.single_label = self.single_label.to(device) + self.key_labels = { + output_name: label.to(device) + for output_name, label in self.key_labels.items() + } + self.component_labels = { + output_name: [ + [labels.to(device) for labels in components_block] + for components_block in components_tmap + ] + for output_name, components_tmap in self.component_labels.items() + } + self.property_labels = { + output_name: [labels.to(device) for labels in properties_tmap] + for output_name, properties_tmap in self.property_labels.items() + } + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["model_ckpt_version"] == v: + update = getattr(checkpoints, f"model_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["model_ckpt_version"] = v + 1 + + if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__: + raise 
RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using model " + f"version {checkpoint['model_ckpt_version']}, while the current model " + f"version is {cls.__checkpoint_version__}." + ) + + return checkpoint + + def get_checkpoint(self) -> Dict: + model_state_dict = self.state_dict() + model_state_dict["finetune_config"] = self.finetune_config + checkpoint = { + "architecture_name": "pet_13", + "model_ckpt_version": self.__checkpoint_version__, + "metadata": self.metadata, + "model_data": { + "model_hypers": self.hypers, + "dataset_info": self.dataset_info, + }, + "epoch": None, + "best_epoch": None, + "model_state_dict": model_state_dict, + "best_model_state_dict": self.state_dict(), + } + return checkpoint + + +def process_non_conservative_stress( + tensor: torch.Tensor, + systems: List[System], + system_indices: torch.Tensor, + num_properties: int, +) -> torch.Tensor: + """ + Symmetrizes and normalizes by the volume rank-2 Cartesian tensors that are meant + to predict the non-conservative stress. + + :param tensor: Tensor of shape [n_atoms, 9 * num_properties]. + :param systems: List of `metatomic.torch.System` objects to process. + :param system_indices: Tensor mapping each atom to its system index [n_atoms]. + :param num_properties: Number of properties in the tensor (e.g., 6 for stress). + :return: Symmetrized tensor of shape [n_atoms, 3, 3, num_properties], divided by the + cell volume. + """ + # Reshape to 3x3 matrix per atom + tensor_as_three_by_three = tensor.reshape(-1, 3, 3, num_properties) + + # Normalize by cell volume + volumes = torch.stack([torch.abs(torch.det(system.cell)) for system in systems]) + # Zero volume can happen due to metatomic's convention of zero cell + # vectors for non-periodic directions. The actual volume is +inf + volumes[volumes == 0.0] = torch.inf + volumes_by_atom = volumes[system_indices].unsqueeze(1).unsqueeze(2).unsqueeze(3) + tensor_as_three_by_three = tensor_as_three_by_three / volumes_by_atom + + # Symmetrize + tensor_as_three_by_three = ( + tensor_as_three_by_three + tensor_as_three_by_three.transpose(1, 2) + ) / 2.0 + + return tensor_as_three_by_three + + +def get_last_layer_features_name(target_name: str) -> str: + """ + Get the auxiliary output name for last layer features of a target. + + :param target_name: Name of the target. + :return: Name of the corresponding last layer features output. + """ + base_name = target_name.replace("mtt::", "") + return f"mtt::aux::{base_name}_last_layer_features" + + +def should_compute_last_layer_features( + output_name: str, requested_outputs: Dict[str, ModelOutput] +) -> bool: + """ + Check if last layer features should be computed for an output. + + :param output_name: Name of the output to check. + :param requested_outputs: Dictionary of requested outputs. + :return: True if last layer features should be computed, False otherwise. + """ + if output_name in requested_outputs: + return True + ll_features_name = get_last_layer_features_name( + output_name.replace("mtt::aux::", "") + ) + return ll_features_name in requested_outputs diff --git a/src/metatrain/pet_13/modules/adaptive_cutoff.py b/src/metatrain/pet_13/modules/adaptive_cutoff.py new file mode 100644 index 0000000000..5f6206bdb1 --- /dev/null +++ b/src/metatrain/pet_13/modules/adaptive_cutoff.py @@ -0,0 +1,166 @@ +from typing import Optional + +import torch + +from .utilities import cutoff_func_bump as cutoff_func + + +# minimum value for the probe cutoff. this avoids getting too close +# to the central atom. 
in practice it could be also set to a larger value +DEFAULT_MIN_PROBE_CUTOFF = 0.5 +# recommended smooth cutoff width for effective neighbor number calculation +# smaller values lead to a more "step-like" behavior, but can be +# numerically unstable. in practice this will be called with the +# same cutoff as the main cutoff function +DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH = 1.0 + + +def get_adaptive_cutoffs( + centers: torch.Tensor, + edge_distances: torch.Tensor, + num_neighbors_adaptive: float, + num_nodes: int, + max_cutoff: float, + min_cutoff: float = DEFAULT_MIN_PROBE_CUTOFF, + cutoff_width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, + probe_spacing: Optional[float] = None, + weight_width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the adaptive cutoff values for each center atom. + + :param centers: Indices of the center atoms. + :param edge_distances: Distances between centers and their neighbors. + :param num_neighbors_adaptive: Target number of neighbors per atom. + :param num_nodes: Total number of center atoms. + :param max_cutoff: Maximum cutoff distance to consider. + :param min_cutoff: Minimum cutoff distance to consider. + :param cutoff_width: Width of the smooth cutoff taper region. + :param probe_spacing: Spacing between probe cutoffs. If None, it will be + automatically determined from the cutoff width. + :param weight_width: Width of the cutoff selection weight function. If None, it + will be automatically determined from the empirical neighbor counts. + :return: Adapted cutoff distances for each center atom. + """ + + # heuristic for the grid spacing of probe cutoffs, based on a + # the smoothness of the cutoff function + if probe_spacing is None: + probe_spacing = cutoff_width / 4.0 + probe_cutoffs = torch.arange( + min_cutoff, + max_cutoff, + probe_spacing, + device=edge_distances.device, + dtype=edge_distances.dtype, + ) + with torch.profiler.record_function("PET::get_effective_num_neighbors"): + effective_num_neighbors = get_effective_num_neighbors( + edge_distances, + probe_cutoffs, + centers, + num_nodes, + width=cutoff_width, + ) + + with torch.profiler.record_function("PET::get_cutoff_weights"): + cutoffs_weights = get_gaussian_cutoff_weights( + effective_num_neighbors, + num_neighbors_adaptive, + width=weight_width, + ) + with torch.profiler.record_function("PET::calculate_adapted_cutoffs"): + adapted_atomic_cutoffs = probe_cutoffs @ cutoffs_weights.T + return adapted_atomic_cutoffs + + +def get_effective_num_neighbors( + edge_distances: torch.Tensor, + probe_cutoffs: torch.Tensor, + centers: torch.Tensor, + num_nodes: int, + width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, +) -> torch.Tensor: + """ + Computes the effective number of neighbors for each probe cutoff. + + :param edge_distances: Distances between centers and their neighbors. + :param probe_cutoffs: Probe cutoff distances. + :param centers: Indices of the center atoms. + :param num_nodes: Total number of center atoms. + :param width: Width of the cutoff function. + :return: Effective number of neighbors for each center atom and probe cutoff. 
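A toy illustration of the adaptive cutoff above (hypothetical distances, in the same length unit as the cutoff): a densely coordinated center should receive a smaller adapted cutoff than a sparsely coordinated one, for the same target neighbor count.

    import torch

    centers = torch.tensor([0, 0, 0, 1, 1, 1])
    edge_distances = torch.tensor([1.0, 1.2, 1.5, 2.5, 3.5, 4.5])

    cutoffs = get_adaptive_cutoffs(
        centers,
        edge_distances,
        num_neighbors_adaptive=2.0,
        num_nodes=2,
        max_cutoff=5.0,
    )
    # cutoffs[0] should come out smaller than cutoffs[1]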
+ """ + + weights = cutoff_func( + edge_distances.unsqueeze(0), probe_cutoffs.unsqueeze(1), width + ) + + probe_num_neighbors = torch.zeros( + (len(probe_cutoffs), num_nodes), + dtype=edge_distances.dtype, + device=edge_distances.device, + ) + # accumulate the weights for all probe cutoffs and center atoms at once + probe_num_neighbors.index_add_(1, centers, weights) + probe_num_neighbors = probe_num_neighbors.T + + return probe_num_neighbors + + +def get_gaussian_cutoff_weights( + effective_num_neighbors: torch.Tensor, + num_neighbors_adaptive: float, + width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the weights for each probe cutoff based on + the effective number of neighbors using Gaussian weights + centered at the expected number of neighbors. + + :param effective_num_neighbors: Effective number of neighbors for each center atom + and probe cutoff. + :param num_neighbors_adaptive: Target maximum number of neighbors per atom. + :param width: Width of the Gaussian cutoff selection function. + :return: Weights for each probe cutoff. + """ + diff = effective_num_neighbors - num_neighbors_adaptive + + # adds a "baseline" corresponding to uniformly-distributed atoms + # this has multiple "good" effects: it pushes the cutoff "out" when + # there are few neighbors, and "in" when there are many, and it + # stabilizes the weights with respect to variations in the neighbor + # distribution when there are empty ranges leading to "flat" + # neighbor count distribution + x = torch.linspace( + 0, + 1, + effective_num_neighbors.shape[1], + device=effective_num_neighbors.device, + dtype=effective_num_neighbors.dtype, + ) + baseline = num_neighbors_adaptive * x**3 + + diff = diff + baseline.unsqueeze(0) + if width is None: + # adaptive width from neighbor-count slope along probe axis (last dim) + eps = 1e-12 + if diff.shape[-1] == 1: + # Can't compute gradient from single point; use scaled diff as proxy + width_t = diff.abs() * 0.5 + eps + else: + # Compute numerical gradient: centered differences for interior, + # one-sided differences at boundaries + (width_t,) = torch.gradient(diff, dim=-1) + width_t = width_t.abs().clamp_min(eps) + else: + width_t = torch.ones_like(diff) * width + + logw = -0.5 * (diff / width_t) ** 2 + weights = torch.exp(logw - logw.max()) + + # row-wise normalization of the weights + weights_sum = weights.sum(dim=1, keepdim=True) + weights = weights / weights_sum + + return weights diff --git a/src/metatrain/pet_13/modules/finetuning.py b/src/metatrain/pet_13/modules/finetuning.py new file mode 100644 index 0000000000..27f2086f19 --- /dev/null +++ b/src/metatrain/pet_13/modules/finetuning.py @@ -0,0 +1,264 @@ +# mypy: disable-error-code=misc +# We ignore misc errors in this file because TypedDict +# with default values is not allowed by mypy. 
+from typing import Any, Optional, Tuple + +import torch +import torch.nn as nn +from typing_extensions import Literal, NotRequired, TypedDict + + +class LoRaFinetuneConfig(TypedDict): + """Configuration for LoRA finetuning strategy.""" + + rank: int + """Rank of the LoRA matrices.""" + alpha: float + """Scaling factor for the LoRA matrices.""" + target_modules: NotRequired[list[str]] + + +class HeadsFinetuneConfig(TypedDict): + """Configuration for heads finetuning strategy.""" + + head_modules: list[str] + """List of module name prefixes for the prediction heads to finetune.""" + last_layer_modules: list[str] + """List of module name prefixes for the last layers to finetune.""" + + +class NoFinetuneHypers(TypedDict): + """Hypers that indicate that no finetuning is to be applied.""" + + read_from: None = None + """No finetuning is indicated by setting this argument to None. + + The rest of finetuning hyperparameters are then ignored. + """ + method: NotRequired[Any] + config: NotRequired[Any] + inherit_heads: NotRequired[Any] + + +class FullFinetuneHypers(TypedDict): + """Hyperparameters to use full finetuning of PET models. + + This means all model parameters are trainable. + """ + + method: Literal["full"] = "full" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: NotRequired[Any] + """No configuration needed for full finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class LoRaFinetuneHypers(TypedDict): + """Hyperparameters for LoRA finetuning of PET models. + + Injects LoRA layers and finetunes only them. + """ + + method: Literal["lora"] = "lora" + """Finetuning method to use""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: LoRaFinetuneConfig + """Configuration for LoRA finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class HeadsFinetuneHypers(TypedDict): + """Hyperparameters for heads finetuning of PET models. + + Freezes all model parameters except for the prediction heads + and last layers. + """ + + method: Literal["heads"] = "heads" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: HeadsFinetuneConfig + """Configuration for heads finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +FinetuneHypers = FullFinetuneHypers | LoRaFinetuneHypers | HeadsFinetuneHypers + + +def apply_finetuning_strategy(model: nn.Module, strategy: FinetuneHypers) -> nn.Module: + """ + Apply the specified finetuning strategy to the model. + This function modifies the model in place based on the provided strategy. + + :param model: The model to be finetuned. + :param strategy: A dictionary specifying the finetuning strategy. + The strategy method can be one of the following: + - lora: Inject LoRA layers into the model, or reapply training if already + present. 
+ - heads: Freeze all parameters except for the heads and last layers. + - full: All parameters are trainable. + Additionally, the strategy can include an "inherit_heads" key, + which is a dictionary mapping the new trainable targets to the existing + targets in the model. This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization. + :return: The modified model with the finetuning strategy applied. + """ + + for param in model.parameters(): + param.requires_grad = True + + if strategy["method"] == "full": + # Full finetuning, all parameters are trainable + pass + + elif strategy["method"] == "lora": + lora_config = strategy["config"] + lora_already_applied = any(isinstance(m, LoRALinear) for m in model.modules()) + if not lora_already_applied: + model_device = next(model.parameters()).device + model_dtype = next(model.parameters()).dtype + model = inject_lora_layers( + model, + target_modules=tuple( + lora_config.get("target_modules", ["input_linear", "output_linear"]) + ), + rank=lora_config.get("rank", 4), + alpha=lora_config.get("alpha", 8), + device=model_device, + dtype=model_dtype, + ) + + # Freeze all except LoRA + for name, param in model.named_parameters(): + if "lora_" not in name: + param.requires_grad = False + + elif strategy["method"] == "heads": + heads_config = strategy.get( + "config", + { + "head_modules": ["node_heads", "edge_heads"], + "last_layer_modules": ["node_last_layers", "edge_last_layers"], + }, + ) + + head_keywords = heads_config.get("head_modules", []) + last_layer_keywords = heads_config.get("last_layer_modules", []) + + for name, param in model.named_parameters(): + if any(name.startswith(kw) for kw in head_keywords + last_layer_keywords): + param.requires_grad = True + else: + param.requires_grad = False + + else: + raise ValueError( + f"Unknown finetuning strategy: {strategy['method']}. Available methods " + "are: 'full', 'lora', 'heads'." + ) + + model.finetune_config = strategy + + inherit_heads_config = strategy["inherit_heads"] + if inherit_heads_config: + for dest_target_name, source_target_name in inherit_heads_config.items(): + model_parameters = dict(model.named_parameters()) + if not any(f".{source_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the source target name '{source_target_name}' was not found in " + "the model. Please specify the correct source target name." + ) + if not any(f".{dest_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the destination target name '{dest_target_name}' was not found " + "in the model. Please specify the correct destination target name." + ) + for name, param in model_parameters.items(): + if f".{source_target_name}." in name: + corresponding_dest_name = name.replace( + source_target_name, dest_target_name + ) + if corresponding_dest_name in model_parameters: + model_parameters[corresponding_dest_name].data.copy_(param.data) + else: + raise ValueError( + f"Destination head '{dest_target_name}' not found in model." + ) + return model + + +def inject_lora_layers( + model: nn.Module, + target_modules: Tuple[str, ...] = ("input_linear", "output_linear"), + rank: int = 4, + alpha: float = 1.0, + device: Optional[str] = None, + dtype: Optional[torch.dtype] = None, +) -> nn.Module: + """ + Inject LoRA layers into the model. 
+ This function replaces the specified linear layers in the model with + LoRALinear layers. + + :param model: The model to modify. + :param target_modules: A tuple of strings specifying the names of the attributes of + the modules to be replaced with LoRA layers. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + :param device: The device to which the LoRA layers should be moved. If None, the + LoRA layers will be on the same device as the original model. + :param dtype: The data type to which the LoRA layers should be converted. If + None, the LoRA layers will have the same dtype as the original model. + :return: The modified model with LoRA layers injected. + """ + for _, module in model.named_modules(): + for attr in target_modules: + if hasattr(module, attr): + linear = getattr(module, attr) + if isinstance(linear, nn.Linear): + lora_linear = LoRALinear(linear, rank=rank, alpha=alpha) + lora_linear = lora_linear.to(dtype=dtype, device=device) + setattr(module, attr, lora_linear) + return model + + +class LoRALinear(nn.Module): + """ + LoRA Linear layer. + This is a wrapper around nn.Linear that adds LoRA functionality. + LoRA is a technique for low-rank adaptation of large language models. + It allows for efficient fine-tuning of large models by injecting low-rank + matrices into the model's weights. + + :param linear_layer: The original linear layer to be wrapped. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + """ + + def __init__(self, linear_layer: nn.Module, rank: int = 4, alpha: float = 1.0): + super().__init__() + self.linear = linear_layer + self.lora_A = nn.Linear(linear_layer.in_features, rank, bias=False) + self.lora_B = nn.Linear(rank, linear_layer.out_features, bias=False) + self.scaling = alpha / rank + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + self.scaling * self.lora_B(self.lora_A(x)) diff --git a/src/metatrain/pet_13/modules/nef.py b/src/metatrain/pet_13/modules/nef.py new file mode 100644 index 0000000000..20ed6dc8fc --- /dev/null +++ b/src/metatrain/pet_13/modules/nef.py @@ -0,0 +1,248 @@ +""" +Module with functions to manipulate NEF (Node Edge Feature) arrays. + +The NEF representation is what the internals of PET use. +In the NEF representation, the first dimension is the center node +(i.e. the "i" node in an "i -> j" edge), and the second dimension +is the edges for that node. Not all center nodes have the same number +of edges, so padding is used to ensure that all nodes have the same +number of edges. + +Most of the functions have the purpose of converting between +edge arrays with shape (n_edges, ...) and NEF arrays with shape +(n_nodes, n_edges_per_node, ...). +""" + +from typing import List, Optional, Tuple + +import torch + + +def get_nef_indices( + centers: torch.Tensor, n_nodes: int, n_edges_per_node: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes tensors of indices useful to convert between edge + and NEF layouts; the usage and function of `nef_indices` and + `nef_to_edges_neighbor` is clear in the ``edge_array_to_nef`` + and ``nef_array_to_edges`` functions below. + + :param centers: A 1D tensor of shape (n_edges,) containing the + indices of the center nodes for each edge, with the center nodes + being the "i" node in an "i -> j" edge. + :param n_nodes: The number of nodes in the graph. + :param n_edges_per_node: The maximum number of edges per node. 
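+
+    As a small, purely illustrative example: for ``centers = [0, 0, 1]`` with
+    ``n_nodes = 2`` and ``n_edges_per_node = 2``, this returns
+    ``nef_indices = [[0, 1], [2, 0]]``, ``nef_to_edges_neighbor = [0, 1, 0]``
+    and ``nef_mask = [[True, True], [True, False]]``; node 1 has only one real
+    edge, so its second NEF slot is padding.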
+ + :return: A tuple with three tensors (nef_indices, nef_to_edges_neighbor, nef_mask). + In particular: + nef_array = edge_array[nef_indices] + edge_array = nef_array[centers, nef_to_edges_neighbor] + The third output, nef_mask, is a mask that can be used to + filter out the padding values in the NEF array, as different + nodes will have, in general, different number of edges. + """ + + bincount = torch.bincount(centers, minlength=n_nodes) + + arange = torch.arange(n_edges_per_node, device=centers.device) + arange_expanded = arange.view(1, -1).expand(n_nodes, -1) + nef_mask = arange_expanded < bincount.view(-1, 1) + + argsort = torch.argsort(centers, stable=True) + + nef_indices = torch.zeros( + (n_nodes, n_edges_per_node), dtype=torch.long, device=centers.device + ) + nef_indices[nef_mask] = argsort + + nef_to_edges_neighbor = torch.empty_like(centers, dtype=torch.long) + nef_to_edges_neighbor[argsort] = arange_expanded[nef_mask] + + return nef_indices, nef_to_edges_neighbor, nef_mask + + +def get_corresponding_edges(array: torch.Tensor) -> torch.Tensor: + """ + Computes the corresponding edge (i.e., the edge that goes in the + opposite direction) for each edge in the array; this is useful + in the message-passing operation. + + :param array: A 2D tensor of shape (n_edges, 5). For each i -> j + edge, the first column contains the index of the center node i, + the second column contains the index of the neighbor node j, + and the last three columns contain the cell shifts along x, y, and z + directions, respectively. + + :return: A 1D tensor of shape (n_edges,) containing, for each edge, + the index of the corresponding edge (i.e., the edge that goes + in the opposite direction). If the input array is empty, an + empty tensor is returned. + """ + + if array.numel() == 0: + return torch.empty((0,), dtype=array.dtype, device=array.device) + + array = array.to(torch.int64) # avoid overflow + + centers = array[:, 0] + neighbors = array[:, 1] + cell_shifts_x = array[:, 2] + cell_shifts_y = array[:, 3] + cell_shifts_z = array[:, 4] + + # will be useful later + negative_cell_shifts_x = -cell_shifts_x + negative_cell_shifts_y = -cell_shifts_y + negative_cell_shifts_z = -cell_shifts_z + + # create a unique identifier for each edge + # first, we shift the cell_shifts so that the minimum value is 0 + min_cell_shift_x = cell_shifts_x.min() + cell_shifts_x = cell_shifts_x - min_cell_shift_x + negative_cell_shifts_x = negative_cell_shifts_x - min_cell_shift_x + + min_cell_shift_y = cell_shifts_y.min() + cell_shifts_y = cell_shifts_y - min_cell_shift_y + negative_cell_shifts_y = negative_cell_shifts_y - min_cell_shift_y + + min_cell_shift_z = cell_shifts_z.min() + cell_shifts_z = cell_shifts_z - min_cell_shift_z + negative_cell_shifts_z = negative_cell_shifts_z - min_cell_shift_z + + max_centers_neigbors = centers.max() + 1 # same as neighbors.max() + 1 + max_shift_x = cell_shifts_x.max() + 1 + max_shift_y = cell_shifts_y.max() + 1 + max_shift_z = cell_shifts_z.max() + 1 + + size_1 = max_shift_z + size_2 = max_shift_y * size_1 + size_3 = max_shift_x * size_2 + size_4 = max_centers_neigbors * size_3 + + unique_id = ( + centers * size_4 + + neighbors * size_3 + + cell_shifts_x * size_2 + + cell_shifts_y * size_1 + + cell_shifts_z + ) + + # the inverse is the same, but centers and neighbors are swapped + # and we use the negative values of the cell_shifts + unique_id_inverse = ( + neighbors * size_4 + + centers * size_3 + + negative_cell_shifts_x * size_2 + + negative_cell_shifts_y * size_1 + + 
negative_cell_shifts_z + ) + + unique_id_argsort = unique_id.argsort() + unique_id_inverse_argsort = unique_id_inverse.argsort() + + corresponding_edges = torch.empty_like(centers) + corresponding_edges[unique_id_argsort] = unique_id_inverse_argsort + + return corresponding_edges.to(array.dtype) + + +def edge_array_to_nef( + edge_array: torch.Tensor, + nef_indices: torch.Tensor, + mask: Optional[torch.Tensor] = None, + fill_value: float = 0.0, +) -> torch.Tensor: + """Converts an edge array to a NEF array. + + :param edge_array: A tensor where the first dimension is the index of + the edge, i.e. with shape (n_edges, ...). + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param mask: An optional boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. If provided, + the output NEF array will have the values in the positions + where the mask is False set to ``fill_value``. + :param fill_value: The value to use to fill the positions in the + NEF array where the mask is False. Only used if ``mask`` is + provided. + + :return: A tensor with the same information as ``edge_array``, + but in NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + If ``mask`` is provided, the values in the positions where + the mask is False are set to ``fill_value``. + """ + if mask is None: + return edge_array[nef_indices] + else: + return torch.where( + mask.reshape(mask.shape + (1,) * (len(edge_array.shape) - 1)), + edge_array[nef_indices], + fill_value, + ) + + +def nef_array_to_edges( + nef_array: torch.Tensor, centers: torch.Tensor, nef_to_edges_neighbor: torch.Tensor +) -> torch.Tensor: + """Converts a NEF array to an edge array. + + :param nef_array: A tensor where the first two dimensions are the + indices of the NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + :param centers: The indices of the center nodes for each edge. + :param nef_to_edges_neighbor: The indices of the edges for each + neighbor in the NEF layout, as returned by the ``get_nef_indices`` function. + + :return: A tensor with the same information as ``nef_array``, + but in edge layout, i.e. with shape (n_edges, ...). + """ + return nef_array[centers, nef_to_edges_neighbor] + + +def compute_reversed_neighbor_list( + nef_indices: torch.Tensor, + corresponding_edges: torch.Tensor, + nef_mask: torch.Tensor, +) -> torch.Tensor: + """ + Creates a reversed neighborlist, where for each + center atom `i` and its neighbor `j` in the original + neighborlist, the position of atom `i` in the list + of neighbors of atom `j` is returned. + + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param corresponding_edges: The indices of the corresponding edges, + as returned by the ``get_corresponding_edges`` function. + :param nef_mask: A boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. + :return: A tensor of the same shape as ``nef_indices``, + where each entry contains the position of the center + atom in the neighborlist of the corresponding neighbor atom. 
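+        Entries corresponding to padded (non-existent) edges, i.e. where
+        ``nef_mask`` is False, are set to 0.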
+ """ + num_atoms, max_num_neighbors = nef_indices.shape + + flat_edge_indices = nef_indices.reshape(-1) + flat_positions = torch.arange(max_num_neighbors, device=nef_indices.device).repeat( + num_atoms + ) + flat_mask = nef_mask.reshape(-1) + + if flat_edge_indices.numel() == 0: + max_edge_index = 0 + else: + max_edge_index = int(flat_edge_indices.max().item()) + 1 + size: List[int] = [max_edge_index] + + edge_index_to_position = torch.full( + size, + 0, + dtype=torch.long, + device=nef_indices.device, + ) + edge_index_to_position[flat_edge_indices[flat_mask]] = flat_positions[flat_mask] + + reverse_edge_idx = corresponding_edges[nef_indices] + reversed_neighbor_list = edge_index_to_position[reverse_edge_idx] + reversed_neighbor_list = reversed_neighbor_list.masked_fill(~nef_mask, 0) + + return reversed_neighbor_list diff --git a/src/metatrain/pet_13/modules/structures.py b/src/metatrain/pet_13/modules/structures.py new file mode 100644 index 0000000000..f7e67f9eaf --- /dev/null +++ b/src/metatrain/pet_13/modules/structures.py @@ -0,0 +1,377 @@ +from typing import List, Optional, Tuple + +import torch +from metatensor.torch import Labels +from metatomic.torch import NeighborListOptions, System + +from .adaptive_cutoff import get_adaptive_cutoffs +from .nef import ( + compute_reversed_neighbor_list, + edge_array_to_nef, + get_corresponding_edges, + get_nef_indices, +) +from .utilities import cutoff_func_bump, cutoff_func_cosine + + +def concatenate_structures( + systems: List[System], + neighbor_list_options: NeighborListOptions, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, +]: + """ + Concatenate a list of systems into a single batch. + + :param systems: List of systems to concatenate. + :param neighbor_list_options: Options for the neighbor list. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the concatenated positions, centers, neighbors, + species, cells, cell shifts, system indices, and sample labels. 
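+
+    When ``selected_atoms`` is provided, only the selected atoms (plus any atoms
+    appearing as neighbor-list centers) are kept for each system, and the center
+    and neighbor indices of the neighbor list are remapped to this reduced atom
+    set. Atom indices in the returned tensors are local to the batch, offset by
+    a running counter over the systems.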
+ """ + + positions: List[torch.Tensor] = [] + centers: List[torch.Tensor] = [] + neighbors: List[torch.Tensor] = [] + species: List[torch.Tensor] = [] + cell_shifts: List[torch.Tensor] = [] + cells: List[torch.Tensor] = [] + system_indices: List[torch.Tensor] = [] + atom_indices: List[torch.Tensor] = [] + node_counter = 0 + + for i, system in enumerate(systems): + assert len(system.known_neighbor_lists()) >= 1, "no neighbor list found" + neighbor_list = system.get_neighbor_list(neighbor_list_options) + nl_values = neighbor_list.samples.values + + centers_values = nl_values[:, 0] + neighbors_values = nl_values[:, 1] + cell_shifts_values = nl_values[:, 2:] + + if selected_atoms is not None: + system_selected_atoms = selected_atoms.values[:, 1][ + selected_atoms.values[:, 0] == i + ] + unique_centers = torch.unique(centers_values) + system_selected_atoms = torch.unique( + torch.cat([system_selected_atoms, unique_centers]) + ) + # calculate the mapping from the ghost atoms to the real atoms + if torch.numel(unique_centers) == 0: + max_center_index = -1 + else: + max_center_index = int(unique_centers.max()) + ghost_to_real_index = torch.full( + [ + max_center_index + 1, + ], + -1, + device=centers_values.device, + dtype=centers_values.dtype, + ) + for j, unique_center_index in enumerate(unique_centers): + ghost_to_real_index[unique_center_index] = j + + centers_values = ghost_to_real_index[centers_values] + neighbors_values = ghost_to_real_index[neighbors_values] + else: + system_selected_atoms = torch.arange( + len(system), device=system.positions.device + ) + + positions.append(system.positions[system_selected_atoms]) + species.append(system.types[system_selected_atoms]) + + centers.append(centers_values + node_counter) + neighbors.append(neighbors_values + node_counter) + cell_shifts.append(cell_shifts_values) + + cells.append(system.cell) + + node_counter += len(system_selected_atoms) + system_indices.append( + torch.full((len(system_selected_atoms),), i, device=system.positions.device) + ) + atom_indices.append( + torch.arange(len(system_selected_atoms), device=system.positions.device) + ) + + positions = torch.cat(positions) + centers = torch.cat(centers) + neighbors = torch.cat(neighbors) + species = torch.cat(species) + cells = torch.stack(cells) + cell_shifts = torch.cat(cell_shifts) + system_indices = torch.cat(system_indices) + atom_indices = torch.cat(atom_indices) + + sample_values = torch.stack( + [system_indices, atom_indices], + dim=1, + ) + sample_labels = Labels( + names=["system", "atom"], + values=sample_values, + assume_unique=True, + ) + + return ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) + + +def systems_to_batch( + systems: List[System], + options: NeighborListOptions, + all_species_list: List[int], + species_to_species_index: torch.Tensor, + cutoff_function: str, + cutoff_width: float, + num_neighbors_adaptive: Optional[float] = None, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, + torch.Tensor, + torch.Tensor, +]: + """ + Converts a list of systems to a batch required for the PET model. + + :param systems: List of systems to convert to a batch. + :param options: Options for the neighbor list. + :param all_species_list: List of all atomic species in the dataset. + :param species_to_species_index: Mapping from atomic species to species indices. 
+ :param cutoff_function: Type of the smoothing function at the cutoff. + :param cutoff_width: Width of the cutoff function for a cutoff mask. + :param num_neighbors_adaptive: Optional maximum number of neighbors per atom. + If provided, the adaptive cutoff scheme will be used for each atom to + approximately select this number of neighbors. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the batch tensors. + The batch consists of the following tensors: + - `element_indices_nodes`: The atomic species of the central atoms + - `element_indices_neighbors`: The atomic species of the neighboring atoms + - `edge_vectors`: The cartesian edge vectors between the central atoms and their + neighbors + - `edge_distances`: The distances between the central atoms and their neighbors + - `padding_mask`: A padding mask indicating which neighbors are real, and which + are padded + - `reverse_neighbor_index`: The reversed neighbor list for each central atom + - `cutoff_factors`: The cutoff function values for each edge + - `system_indices`: The system index for each atom in the batch + - `sample_labels`: Labels indicating the system and atom indices for each atom + + """ + ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) = concatenate_structures(systems, options, selected_atoms) + + # somehow the backward of this operation is very slow at evaluation, + # where there is only one cell, therefore we simplify the calculation + # for that case + if len(cells) == 1: + cell_contributions = cell_shifts.to(cells.dtype) @ cells[0] + else: + cell_contributions = torch.einsum( + "ab, abc -> ac", + cell_shifts.to(cells.dtype), + cells[system_indices[centers]], + ) + edge_vectors = positions[neighbors] - positions[centers] + cell_contributions + edge_distances = torch.norm(edge_vectors, dim=-1) + 1e-15 + + if selected_atoms is not None: + if torch.numel(centers) == 0: + num_nodes = 0 + else: + num_nodes = int(centers.max()) + 1 + else: + num_nodes = len(positions) + + if num_neighbors_adaptive is not None: + with torch.profiler.record_function("PET::get_adaptive_cutoffs"): + # Adaptive cutoff scheme to approximately select `num_neighbors_adaptive` + # neighbors for each atom + atomic_cutoffs = get_adaptive_cutoffs( + centers, + edge_distances, + num_neighbors_adaptive, + num_nodes, + options.cutoff, + cutoff_width=cutoff_width, + ) + # Symmetrize the cutoffs between pairs of atoms (PET needs this symmetry + # due to its corresponding edge indexing ij -> ji) + pair_cutoffs = (atomic_cutoffs[centers] + atomic_cutoffs[neighbors]) / 2.0 + with torch.profiler.record_function("PET::adaptive_cutoff_masking"): + # Apply cutoff mask + cutoff_mask = edge_distances <= pair_cutoffs + + pair_cutoffs = pair_cutoffs[cutoff_mask] + centers = centers[cutoff_mask] + neighbors = neighbors[cutoff_mask] + edge_vectors = edge_vectors[cutoff_mask] + cell_shifts = cell_shifts[cutoff_mask] + edge_distances = edge_distances[cutoff_mask] + else: + pair_cutoffs = options.cutoff * torch.ones( + len(centers), device=positions.device, dtype=positions.dtype + ) + + num_neighbors = torch.bincount(centers, minlength=num_nodes) + max_edges_per_node = int(torch.max(num_neighbors)) + + # uncomment these to print out stats on the adaptive cutoff behavior + # print("adaptive_cutoffs", *pair_cutoffs.tolist()) + # print("num_neighbors", *num_neighbors.tolist()) + + if cutoff_function.lower() == "bump": + # use bump switching 
function for adaptive cutoff
+        cutoff_factors = cutoff_func_bump(edge_distances, pair_cutoffs, cutoff_width)
+    elif cutoff_function.lower() == "cosine":
+        # backward-compatible cosine switching for fixed cutoff
+        cutoff_factors = cutoff_func_cosine(edge_distances, pair_cutoffs, cutoff_width)
+    else:
+        raise ValueError(
+            f"Unknown cutoff function type: {cutoff_function}. "
+            "Supported types are 'Cosine' and 'Bump'."
+        )
+
+    # Convert to NEF (Node-Edge-Feature) format:
+    nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices(
+        centers, num_nodes, max_edges_per_node
+    )
+
+    # Element indices
+    element_indices_nodes = species_to_species_index[species]
+    element_indices_neighbors = element_indices_nodes[neighbors]
+
+    # Send everything to NEF:
+    edge_vectors = edge_array_to_nef(edge_vectors, nef_indices)
+    edge_distances = torch.sqrt(torch.sum(edge_vectors**2, dim=2) + 1e-15)
+    element_indices_neighbors = edge_array_to_nef(
+        element_indices_neighbors, nef_indices
+    )
+    cutoff_factors = edge_array_to_nef(cutoff_factors, nef_indices, nef_mask, 0.0)
+
+    corresponding_edges = get_corresponding_edges(
+        torch.concatenate(
+            [centers.unsqueeze(-1), neighbors.unsqueeze(-1), cell_shifts],
+            dim=-1,
+        )
+    )
+
+    # These are the two arrays we need for message passing with edge reversals,
+    # if indexing happens in a two-dimensional way:
+    # edges_ji = edges_ij[reversed_neighbor_list, neighbors_index]
+    reversed_neighbor_list = compute_reversed_neighbor_list(
+        nef_indices, corresponding_edges, nef_mask
+    )
+    neighbors_index = edge_array_to_nef(neighbors, nef_indices).to(torch.int64)
+
+    # Here, we compute the array that allows indexing into a flattened
+    # version of the edge array (where the first two dimensions are merged):
+    reverse_neighbor_index = (
+        neighbors_index * neighbors_index.shape[1] + reversed_neighbor_list
+    )
+    # At this point, we have `reverse_neighbor_index[~nef_mask] = 0`, which however
+    # creates too many of the same index which slows down backward enormously.
+    # (See https://github.com/pytorch/pytorch/issues/41162)
+    # We therefore replace the padded indices with a sequence of unique indices.
+    reverse_neighbor_index[~nef_mask] = torch.arange(
+        int(torch.sum(~nef_mask)), device=reverse_neighbor_index.device
+    )
+
+    return (
+        element_indices_nodes,
+        element_indices_neighbors,
+        edge_vectors,
+        edge_distances,
+        nef_mask,
+        reverse_neighbor_index,
+        cutoff_factors,
+        system_indices,
+        sample_labels,
+        centers,
+        nef_to_edges_neighbor,
+    )
+
+
+def get_pair_sample_labels(
+    systems: List[System],
+    sample_labels: Labels,
+    nl_options: NeighborListOptions,
+    device: torch.device,
+) -> Labels:
+    """
+    Builds the pair sample labels for the input ``systems``, based on the
+    pre-computed neighbor list. These are 'off-site', i.e. not including
+    self-interactions.
+
+    :param systems: List of systems to build the pair sample labels for.
+    :param sample_labels: The sample labels for per-atom quantities.
+    :param nl_options: The neighbor list options to use for building the offsite labels.
+    :param device: The device to put the labels on.
+    :return: The pair sample labels for the off-site pairs.
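+        The labels use the names ``["system", "first_atom", "second_atom",
+        "cell_shift_a", "cell_shift_b", "cell_shift_c"]``.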
+ """ + sample_names = [ + "system", + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ] + + pair_sample_values = [] + for system_idx, system in enumerate(systems): + neighbor_list = system.get_neighbor_list(nl_options) + nl_values = neighbor_list.samples.values + + pair_sample_values.append( + torch.hstack( + [ + torch.full( + (nl_values.shape[0], 1), + system_idx, + dtype=torch.int32, + device=device, + ), + nl_values, + ], + ) + ) + pair_sample_values = torch.vstack(pair_sample_values) + pair_sample_labels = Labels(sample_names, pair_sample_values).to(device=device) + + return pair_sample_labels diff --git a/src/metatrain/pet_13/modules/transformer.py b/src/metatrain/pet_13/modules/transformer.py new file mode 100644 index 0000000000..e6836ae676 --- /dev/null +++ b/src/metatrain/pet_13/modules/transformer.py @@ -0,0 +1,556 @@ +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from .utilities import DummyModule +from sphericart.torch import SolidHarmonics + + +AVAILABLE_NORMALIZATIONS = ["LayerNorm", "RMSNorm"] +AVAILABLE_TRANSFORMER_TYPES = ["PostLN", "PreLN"] +AVAILABLE_ACTIVATIONS = ["SiLU", "SwiGLU"] + + +class FeedForward(nn.Module): + def __init__(self, d_model: int, dim_feedforward: int, activation: str) -> None: + super().__init__() + + # Check if activation is "swiglu" string + if activation.lower() == "swiglu": + # SwiGLU mode: single projection produces both "value" and "gate" + self.w_in = nn.Linear(d_model, 2 * dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = torch.nn.Identity() + self.is_swiglu = True + else: + # Standard mode: regular activation function + self.w_in = nn.Linear(d_model, dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = getattr(F, activation.lower()) + self.is_swiglu = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_swiglu: + # SwiGLU activation: split into value and gate + v, g = self.w_in(x).chunk(2, dim=-1) + x = v * torch.sigmoid(g) + x = self.w_out(x) + else: + # Standard activation + x = self.w_in(x) + x = self.activation(x) + x = self.w_out(x) + return x + + +class AttentionBlock(nn.Module): + """ + Multi-head attention block. + + :param total_dim: The total dimension of the input and output tensors. + :param num_heads: The number of attention heads. + :param temperature: An additional scaling factor for attention scores. + This is combined with the standard scaling by the square root of + the head dimension. + :param epsilon: A small value to avoid division by zero. + """ + + def __init__( + self, + total_dim: int, + num_heads: int, + temperature: float, + epsilon: float = 1e-15, + ) -> None: + super(AttentionBlock, self).__init__() + + self.input_linear = nn.Linear(total_dim, 3 * total_dim) + self.output_linear = nn.Linear(total_dim, total_dim) + + self.num_heads = num_heads + self.epsilon = epsilon + self.temperature = temperature + if total_dim % num_heads != 0: + raise ValueError("total dimension is not divisible by the number of heads") + self.head_dim = total_dim // num_heads + + def forward( + self, x: torch.Tensor, cutoff_factors: torch.Tensor, use_manual_attention: bool + ) -> torch.Tensor: + """ + Forward pass for the attention block. 
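+
+        The cutoff factors are clamped from below at ``epsilon`` and added to
+        the attention logits through a logarithm, so that, up to the clamping,
+        the attention weights are multiplied by the cutoff factors before the
+        softmax normalization:
+        ``softmax(Q K^T / (sqrt(head_dim) * temperature) + log(cutoff_factors))``.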
+ + :param x: The input tensor, of shape (batch_size, seq_length, total_dim) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: The output tensor, of shape (batch_size, seq_length, total_dim) + """ + initial_shape = x.shape + x = self.input_linear(x) + x = x.reshape( + initial_shape[0], initial_shape[1], 3, self.num_heads, self.head_dim + ) + x = x.permute(2, 0, 3, 1, 4) + + queries, keys, values = x[0], x[1], x[2] + attn_weights = torch.clamp(cutoff_factors[:, None, :, :], self.epsilon) + attn_weights = torch.log(attn_weights) + if use_manual_attention: + x = manual_attention(queries, keys, values, attn_weights, self.temperature) + else: + x = torch.nn.functional.scaled_dot_product_attention( + queries, + keys, + values, + attn_mask=attn_weights, + scale=1.0 / (self.head_dim**0.5 * self.temperature), + ) + x = x.transpose(1, 2).reshape(initial_shape) + x = self.output_linear(x) + return x + + +class TransformerLayer(torch.nn.Module): + """ + Single layer of a Transformer. + + :param d_model: The dimension of the model. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param temperature: An additional scaling factor for attention scores. 
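+
+    In the "PostLN" variant, the normalization is applied after the attention
+    and feed-forward residual additions; in the "PreLN" variant, it is applied
+    to the inputs of the attention and feed-forward blocks instead.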
+ """ + + def __init__( + self, + d_model: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + temperature: float = 1.0, + ) -> None: + super(TransformerLayer, self).__init__() + self.attention = AttentionBlock(d_model, n_heads, temperature) + self.transformer_type = transformer_type + self.d_model = d_model + norm_class = getattr(nn, norm) + self.norm_attention = norm_class(d_model) + self.norm_mlp = norm_class(d_model) + self.mlp = FeedForward(d_model, dim_feedforward, activation) + self.expanded_node_features = False + if dim_node_features != d_model: + self.expanded_node_features = True + self.center_contraction = nn.Linear(dim_node_features, d_model) + self.center_expansion = nn.Linear(d_model, dim_node_features) + self.norm_center_features = norm_class(dim_node_features) + self.center_mlp = FeedForward( + dim_node_features, 2 * dim_node_features, activation + ) + else: + self.center_contraction = torch.nn.Identity() + self.center_expansion = torch.nn.Identity() + self.norm_center_features = torch.nn.Identity() + self.center_mlp = torch.nn.Identity() + + def _forward_pre_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + new_tokens = self.attention( + self.norm_attention(tokens), cutoff_factors, use_manual_attention + ) + output_node_embeddings, output_edge_embeddings = torch.split( + new_tokens, [1, new_tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + + output_edge_embeddings = edge_embeddings + output_edge_embeddings + output_edge_embeddings = output_edge_embeddings + self.mlp( + self.norm_mlp(output_edge_embeddings) + ) + + return output_node_embeddings, output_edge_embeddings + + def _forward_post_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + tokens = self.norm_attention( + tokens + self.attention(tokens, cutoff_factors, use_manual_attention) + ) + tokens = self.norm_mlp(tokens + self.mlp(tokens)) + output_node_embeddings, output_edge_embeddings = torch.split( + tokens, [1, tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + return output_node_embeddings, output_edge_embeddings + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward 
pass for a single Transformer layer. + + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + if self.transformer_type == "PostLN": + node_embeddings, edge_embeddings = self._forward_post_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + if self.transformer_type == "PreLN": + node_embeddings, edge_embeddings = self._forward_pre_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class Transformer(torch.nn.Module): + """ + Transformer implementation. + + :param d_model: The dimension of the model. + :param num_layers: The number of transformer layers. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param attention_temperature: The temperature scaling factor for attention + scores. This is combined with the standard scaling by the square root of + the head dimension. + """ + + def __init__( + self, + d_model: int, + num_layers: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + attention_temperature: float = 1.0, + ) -> None: + super(Transformer, self).__init__() + if norm not in AVAILABLE_NORMALIZATIONS: + raise ValueError( + f"Unknown normalization flag: {norm}. " + f"Please choose from: {AVAILABLE_NORMALIZATIONS}" + ) + + if transformer_type not in AVAILABLE_TRANSFORMER_TYPES: + raise ValueError( + f"Unknown transformer flag: {transformer_type}. " + f"Please choose from: {AVAILABLE_TRANSFORMER_TYPES}" + ) + self.transformer_type = transformer_type + + if activation not in AVAILABLE_ACTIVATIONS: + raise ValueError( + f"Unknown activation flag: {activation}. " + f"Please choose from: {AVAILABLE_ACTIVATIONS}" + ) + + self.layers = nn.ModuleList( + [ + TransformerLayer( + d_model=d_model, + n_heads=n_heads, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + temperature=attention_temperature, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the Transformer. 
+ + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + for layer in self.layers: + node_embeddings, edge_embeddings = layer( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class CartesianTransformer(torch.nn.Module): + """ + Cartesian Transformer implementation for handling 3D coordinates. + + :param cutoff: The cutoff distance for neighbor interactions. + :param cutoff_width: The width of the cutoff function. + :param d_model: The dimension of the model. + :param n_head: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param n_layers: The number of transformer layers. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param attention_temperature: The temperature scaling factor for attention scores. + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param n_atomic_species: The number of atomic species. + :param is_first: Whether this is the first transformer in the model. + """ + + def __init__( + self, + cutoff: float, + cutoff_width: float, + d_model: int, + n_head: int, + dim_node_features: int, + dim_feedforward: int, + n_layers: int, + norm: str, + activation: str, + attention_temperature: float, + transformer_type: str, + n_atomic_species: int, + is_first: bool, + ) -> None: + super(CartesianTransformer, self).__init__() + self.is_first = is_first + self.cutoff = cutoff + self.cutoff_width = cutoff_width + self.trans = Transformer( + d_model=d_model, + num_layers=n_layers, + n_heads=n_head, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + attention_temperature=attention_temperature, + ) + + self.l_max = 8 + self.spherical_harmonics = SolidHarmonics(l_max=self.l_max) + self.edge_embedder = nn.Linear((self.l_max + 1) ** 2, d_model) + self.rmsnorm = nn.LayerNorm(d_model) + + if not is_first: + n_merge = 3 + else: + n_merge = 2 + + self.compress = nn.Sequential( + nn.Linear(n_merge * d_model, d_model), + torch.nn.SiLU(), + nn.Linear(d_model, d_model), + ) + + self.neighbor_embedder = DummyModule() # for torchscript + if not is_first: + self.neighbor_embedder = nn.Embedding(n_atomic_species, d_model) + + def forward( + self, + input_node_embeddings: torch.Tensor, + input_messages: torch.Tensor, + element_indices_neighbors: torch.Tensor, + edge_vectors: torch.Tensor, + padding_mask: torch.Tensor, + edge_distances: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the CartesianTransformer. 
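+
+        The edge tokens are built from a solid-harmonics expansion of the edge
+        vectors (up to ``l_max = 8``), concatenated with the incoming messages
+        and, for all but the first transformer of the model, with embeddings of
+        the neighbor species, and then compressed to ``d_model``. The
+        central-atom embedding enters the transformer as an additional token
+        alongside the edge tokens.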
+ + :param input_node_embeddings: The input node embeddings, of shape + (n_nodes, d_model) + :param input_messages: The input messages to the transformer, of shape + (n_nodes, max_num_neighbors, d_model) + :param element_indices_neighbors: The atomic species of the neighboring atoms, + of shape (n_nodes, max_num_neighbors) + :param edge_vectors: The cartesian edge vectors between the central atoms and + their neighbors, of shape (n_nodes, max_num_neighbors, 3) + :param padding_mask: A padding mask indicating which neighbors are real, and + which are padded, of shape (n_nodes, max_num_neighbors) + :param edge_distances: The distances between the central atoms and their + neighbors, of shape (n_nodes, max_num_neighbors) + :param cutoff_factors: The cutoff factors for the edges, of shape + (n_nodes, max_num_neighbors) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (n_nodes, d_model) + - The output edge embeddings, of shape (n_nodes, max_num_neighbors, d_model) + """ + node_embeddings = input_node_embeddings + # edge_embeddings = [edge_vectors, edge_distances[:, :, None]] + # edge_embeddings = torch.cat(edge_embeddings, dim=2) + edge_embeddings = self.spherical_harmonics( + edge_vectors.reshape(-1, 3) + ).reshape(edge_vectors.shape[0], edge_vectors.shape[1], -1) + edge_embeddings = self.edge_embedder(edge_embeddings) + edge_embeddings = self.rmsnorm(edge_embeddings) + + if not self.is_first: + neighbor_elements_embeddings = self.neighbor_embedder( + element_indices_neighbors + ) + edge_tokens = torch.cat( + [edge_embeddings, neighbor_elements_embeddings, input_messages], dim=2 + ) + else: + neighbor_elements_embeddings = torch.empty( + 0, device=edge_vectors.device, dtype=edge_vectors.dtype + ) # for torch script + edge_tokens = torch.cat([edge_embeddings, input_messages], dim=2) + + edge_tokens = self.compress(edge_tokens) + # tokens = torch.cat([node_elements_embedding[:, None, :], tokens], dim=1) + + padding_mask_with_central_token = torch.ones( + padding_mask.shape[0], dtype=torch.bool, device=padding_mask.device + ) + total_padding_mask = torch.cat( + [padding_mask_with_central_token[:, None], padding_mask], dim=1 + ) + + cutoff_subfactors = torch.ones( + padding_mask.shape[0], + dtype=cutoff_factors.dtype, + device=padding_mask.device, + ) + cutoff_factors = torch.cat([cutoff_subfactors[:, None], cutoff_factors], dim=1) + cutoff_factors[~total_padding_mask] = 0.0 + + cutoff_factors = cutoff_factors[:, None, :] + cutoff_factors = cutoff_factors.repeat(1, cutoff_factors.shape[2], 1) + + initial_num_tokens = edge_vectors.shape[1] + max_num_tokens = input_messages.shape[1] + + output_node_embeddings, output_edge_embeddings = self.trans( + node_embeddings[:, None, :], + edge_tokens[:, :max_num_tokens, :], + cutoff_factors=cutoff_factors[ + :, : (max_num_tokens + 1), : (max_num_tokens + 1) + ], + use_manual_attention=use_manual_attention, + ) + if max_num_tokens < initial_num_tokens: + padding = torch.zeros( + output_edge_embeddings.shape[0], + initial_num_tokens - max_num_tokens, + output_edge_embeddings.shape[2], + device=output_edge_embeddings.device, + ) + output_edge_embeddings = torch.cat([output_edge_embeddings, padding], dim=1) + output_node_embeddings = output_node_embeddings.squeeze(1) + return output_node_embeddings, 
output_edge_embeddings + + +def manual_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: torch.Tensor, + temperature: float, +) -> torch.Tensor: + """ + Implements the attention operation manually, using basic PyTorch operations. + We need it because the built-in PyTorch attention does not support double backward, + which is needed when training with conservative forces. + + :param q: The queries + :param k: The keys + :param v: The values + :param attn_mask: The attention mask + :param temperature: An additional scaling factor for attention scores. + :return: The result of the attention operation + """ + attention_weights = ( + torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5 * temperature) + ) + attn_mask + attention_weights = attention_weights.softmax(dim=-1) + attention_output = torch.matmul(attention_weights, v) + return attention_output \ No newline at end of file diff --git a/src/metatrain/pet_13/modules/utilities.py b/src/metatrain/pet_13/modules/utilities.py new file mode 100644 index 0000000000..5795c5aed8 --- /dev/null +++ b/src/metatrain/pet_13/modules/utilities.py @@ -0,0 +1,63 @@ +import torch + + +def cutoff_func_bump( + values: torch.Tensor, cutoff: torch.Tensor, width: float, eps: float = 1e-6 +) -> torch.Tensor: + """ + Bump cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :param eps: Avoid computing at values too close to the edges. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + f[mask_active] = 0.5 * ( + 1 + torch.tanh(1 / torch.tan(torch.pi * scaled_values[mask_active])) + ) + f[mask_smaller] = 1.0 + + return f + + +def cutoff_func_cosine( + values: torch.Tensor, cutoff: torch.Tensor, width: float +) -> torch.Tensor: + """ + Cosine cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + + f[mask_active] = 0.5 + 0.5 * torch.cos(torch.pi * scaled_values[mask_active]) + f[mask_smaller] = 1.0 + return f + + +class DummyModule(torch.nn.Module): + """Dummy torch module to make torchscript happy. 
+ This model should never be run""" + + def __init__(self) -> None: + super(DummyModule, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise RuntimeError("This model should never be run") diff --git a/src/metatrain/pet_13/trainer.py b/src/metatrain/pet_13/trainer.py new file mode 100644 index 0000000000..49e82212bc --- /dev/null +++ b/src/metatrain/pet_13/trainer.py @@ -0,0 +1,639 @@ +import copy +import logging +import math +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Union, cast + +import torch +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader, DistributedSampler + +from metatrain.utils.abc import ModelInterface, TrainerInterface +from metatrain.utils.additive import get_remove_additive_transform +from metatrain.utils.augmentation import RotationalAugmenter +from metatrain.utils.data import ( + CollateFn, + CombinedDataLoader, + Dataset, + get_num_workers, + unpack_batch, + validate_num_workers, +) +from metatrain.utils.distributed.batch_utils import should_skip_batch +from metatrain.utils.distributed.distributed_data_parallel import ( + DistributedDataParallel, +) +from metatrain.utils.distributed.slurm import DistributedEnvironment +from metatrain.utils.evaluate_model import evaluate_model +from metatrain.utils.io import check_file_extension +from metatrain.utils.logging import ROOT_LOGGER, MetricLogger +from metatrain.utils.loss import LossAggregator, LossSpecification +from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists_transform, +) +from metatrain.utils.per_atom import average_by_num_atoms +from metatrain.utils.scaler import get_remove_scale_transform +from metatrain.utils.transfer import batch_to + +from . import checkpoints +from .documentation import TrainerHypers +from .model import PET +from .modules.finetuning import apply_finetuning_strategy + + +def get_scheduler( + optimizer: torch.optim.Optimizer, + train_hypers: TrainerHypers, + steps_per_epoch: int, +) -> LambdaLR: + """ + Get a CosineAnnealing learning-rate scheduler with warmup + + :param optimizer: The optimizer for which to create the scheduler. + :param train_hypers: The training hyperparameters. + :param steps_per_epoch: The number of steps per epoch. + :return: The learning rate scheduler. 
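+
+    With the minimum learning-rate ratio fixed to 0, the learning-rate
+    multiplier is ``step / warmup_steps`` during the linear warmup and
+    ``0.5 * (1 + cos(pi * (step - warmup_steps) / (total_steps - warmup_steps)))``
+    afterwards, where ``warmup_steps = warmup_fraction * total_steps`` and
+    ``total_steps = num_epochs * steps_per_epoch``.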
+ """ + total_steps = train_hypers["num_epochs"] * steps_per_epoch + warmup_steps = int(train_hypers["warmup_fraction"] * total_steps) + min_lr_ratio = 0.0 # hardcoded for now, could be made configurable in the future + + def lr_lambda(current_step: int) -> float: + if current_step < warmup_steps: + # Linear warmup + return float(current_step) / float(max(1, warmup_steps)) + else: + # Cosine decay + progress = (current_step - warmup_steps) / float( + max(1, total_steps - warmup_steps) + ) + cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) + return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay + + scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) + return scheduler + + +class Trainer(TrainerInterface[TrainerHypers]): + __checkpoint_version__ = 12 + + def __init__(self, hypers: TrainerHypers) -> None: + super().__init__(hypers) + + self.optimizer_state_dict: Optional[Dict[str, Any]] = None + self.scheduler_state_dict: Optional[Dict[str, Any]] = None + self.epoch: Optional[int] = None + self.best_epoch: Optional[int] = None + self.best_metric: Optional[float] = None + self.best_model_state_dict: Optional[Dict[str, Any]] = None + self.best_optimizer_state_dict: Optional[Dict[str, Any]] = None + + def train( + self, + model: PET, + dtype: torch.dtype, + devices: List[torch.device], + train_datasets: List[Union[Dataset, torch.utils.data.Subset]], + val_datasets: List[Union[Dataset, torch.utils.data.Subset]], + checkpoint_dir: str, + ) -> None: + assert dtype in PET.__supported_dtypes__ + + is_distributed = self.hypers["distributed"] + is_finetune = self.hypers["finetune"]["read_from"] is not None + + if is_distributed: + if len(devices) > 1: + raise ValueError( + "Requested distributed training with the `multi-gpu` device. " + " If you want to run distributed training with PET, please " + "set `device` to cuda." 
+                )
+            # the calculation of the device number works both when GPUs on different
+            # processes are not visible to each other and when they are
+            distr_env = DistributedEnvironment(self.hypers["distributed_port"])
+            device_number = distr_env.local_rank % torch.cuda.device_count()
+            device = torch.device("cuda", device_number)
+            torch.distributed.init_process_group(backend="nccl", device_id=device)
+            world_size = torch.distributed.get_world_size()
+            rank = torch.distributed.get_rank()
+        else:
+            rank = 0
+            device = devices[0]
+            # only one device, as we don't support non-distributed multi-gpu for now
+
+        if is_distributed:
+            logging.info(f"Training on {world_size} devices with dtype {dtype}")
+        else:
+            logging.info(f"Training on device {device} with dtype {dtype}")
+
+        # Apply fine-tuning strategy if provided
+        if is_finetune:
+            assert self.hypers["finetune"]["read_from"] is not None  # for mypy
+            model = apply_finetuning_strategy(model, self.hypers["finetune"])
+            method = self.hypers["finetune"]["method"]
+            num_params = sum(p.numel() for p in model.parameters())
+            num_trainable_params = sum(
+                p.numel() for p in model.parameters() if p.requires_grad
+            )
+
+            logging.info(f"Applied finetuning strategy: {method}")
+            logging.info(
+                f"Number of trainable parameters: {num_trainable_params} "
+                f"[{num_trainable_params / num_params:.2%}]"
+            )
+            inherit_heads = self.hypers["finetune"]["inherit_heads"]
+            if inherit_heads:
+                logging.info(
+                    "Inheriting initial weights for heads and last layers for targets: "
+                    f"from {list(inherit_heads.values())} to "
+                    f"{list(inherit_heads.keys())}"
+                )
+
+        # Move the model to the device and dtype:
+        model.to(device=device, dtype=dtype)
+        # The additive models of PET are always in float64 (to avoid numerical errors in
+        # the composition weights, which can be very large).
+ for additive_model in model.additive_models: + additive_model.to(dtype=torch.float64) + model.scaler.to(dtype=torch.float64) + + logging.info("Calculating composition weights") + model.additive_models[0].train_model( # this is the composition model + train_datasets, + model.additive_models[1:], + self.hypers["batch_size"], + is_distributed, + self.hypers["atomic_baseline"], + ) + + if self.hypers["scale_targets"]: + logging.info("Calculating scaling weights") + model.scaler.train_model( + train_datasets, + model.additive_models, + self.hypers["batch_size"], + is_distributed, + self.hypers["fixed_scaling_weights"], + ) + + logging.info("Setting up data loaders") + + if is_distributed: + train_samplers = [ + DistributedSampler( + train_dataset, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + ) + for train_dataset in train_datasets + ] + val_samplers = [ + DistributedSampler( + val_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + for val_dataset in val_datasets + ] + else: + train_samplers = [None] * len(train_datasets) + val_samplers = [None] * len(val_datasets) + + # Extract additive models and scaler and move them to CPU/float64 so they + # can be used in the collate function + model.additive_models[0].weights_to(device="cpu", dtype=torch.float64) + additive_models = copy.deepcopy( + model.additive_models.to(dtype=torch.float64, device="cpu") + ) + model.additive_models.to(device) + model.additive_models[0].weights_to(device=device, dtype=torch.float64) + model.scaler.scales_to(device="cpu", dtype=torch.float64) + scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu")) + model.scaler.to(device) + model.scaler.scales_to(device=device, dtype=torch.float64) + + # Create collate functions: + dataset_info = model.dataset_info + train_targets = dataset_info.targets + extra_data_info = dataset_info.extra_data + rotational_augmenter = RotationalAugmenter( + target_info_dict=train_targets, extra_data_info_dict=extra_data_info + ) + requested_neighbor_lists = get_requested_neighbor_lists(model) + collate_fn_train = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ + rotational_augmenter.apply_random_augmentations, + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + collate_fn_val = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ # no augmentation for validation + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + + # Create dataloader for the training datasets: + if self.hypers["num_workers"] is None: + num_workers = get_num_workers() + logging.info( + "Number of workers for data-loading not provided and chosen " + f"automatically. Using {num_workers} workers." + ) + else: + num_workers = self.hypers["num_workers"] + validate_num_workers(num_workers) + + train_dataloaders = [] + for train_dataset, train_sampler in zip( + train_datasets, train_samplers, strict=True + ): + if len(train_dataset) < self.hypers["batch_size"]: + raise ValueError( + f"A training dataset has fewer samples " + f"({len(train_dataset)}) than the batch size " + f"({self.hypers['batch_size']}). " + "Please reduce the batch size." 
+ ) + train_dataloaders.append( + DataLoader( + dataset=train_dataset, + batch_size=self.hypers["batch_size"], + sampler=train_sampler, + shuffle=( + # the sampler takes care of this (if present) + train_sampler is None + ), + drop_last=( + # the sampler takes care of this (if present) + train_sampler is None + ), + collate_fn=collate_fn_train, + num_workers=num_workers, + ) + ) + train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True) + + # Create dataloader for the validation datasets: + val_dataloaders = [] + for val_dataset, val_sampler in zip(val_datasets, val_samplers, strict=True): + val_dataloaders.append( + DataLoader( + dataset=val_dataset, + batch_size=self.hypers["batch_size"], + sampler=val_sampler, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_val, + num_workers=num_workers, + ) + ) + val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False) + + if is_distributed: + model = DistributedDataParallel(model, device_ids=[device]) + + outputs_list = [] + for target_name, target_info in train_targets.items(): + outputs_list.append(target_name) + for gradient_name in target_info.gradients: + outputs_list.append(f"{target_name}_{gradient_name}_gradients") + + # Create a loss function: + loss_hypers = cast(Dict[str, LossSpecification], self.hypers["loss"]) # mypy + loss_fn = LossAggregator(targets=train_targets, config=loss_hypers) + logging.info("Using the following loss functions:") + for name, info in loss_fn.metadata.items(): + logging.info(f"{name}:") + main = {k: v for k, v in info.items() if k != "gradients"} + logging.info(main) + if "gradients" not in info or len(info["gradients"]) == 0: + continue + logging.info("With gradients:") + for grad, ginfo in info["gradients"].items(): + logging.info(f"\t{name}::{grad}: {ginfo}") + + if self.hypers["weight_decay"] is not None: + optimizer = torch.optim.AdamW( + model.parameters(), + lr=self.hypers["learning_rate"], + weight_decay=self.hypers["weight_decay"], + ) + else: + optimizer = torch.optim.Adam( + model.parameters(), lr=self.hypers["learning_rate"] + ) + + if self.optimizer_state_dict is not None and not is_finetune: + # try to load the optimizer state dict, but this is only possible + # if there are no new targets in the model (new parameters) + if not (model.module if is_distributed else model).has_new_targets: + optimizer.load_state_dict(self.optimizer_state_dict) + + # Create a learning rate scheduler + lr_scheduler = get_scheduler(optimizer, self.hypers, len(train_dataloader)) + + if self.scheduler_state_dict is not None and not is_finetune: + # same as the optimizer, try to load the scheduler state dict + if not (model.module if is_distributed else model).has_new_targets: + lr_scheduler.load_state_dict(self.scheduler_state_dict) + + per_structure_targets = self.hypers["per_structure_targets"] + + # Log the initial learning rate: + logging.info(f"Base learning rate: {self.hypers['learning_rate']}") + + start_epoch = -1 if self.epoch is None else self.epoch + epoch = start_epoch + + # Save the untrained model checkpoint: + if self.best_metric is None: + self.best_metric = float("inf") + logging.info("Saving untrained model checkpoint") + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = 
lr_scheduler.state_dict() + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_untrained.ckpt", + ) + + # Train the model: + logging.info("Starting training") + start_epoch = start_epoch + 1 + + for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]): + if is_distributed: + for train_sampler in train_samplers: + train_sampler.set_epoch(epoch) + train_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + val_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + if self.hypers["log_mae"]: + train_mae_calculator = MAEAccumulator( + self.hypers["log_separate_blocks"] + ) + val_mae_calculator = MAEAccumulator(self.hypers["log_separate_blocks"]) + + train_loss = 0.0 + for batch in train_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + optimizer.zero_grad() + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=True, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + train_loss_batch = loss_fn(predictions, targets, extra_data) + + if is_distributed: + # make sure all parameters contribute to the gradient calculation + # to make torch DDP happy + for param in model.parameters(): + train_loss_batch += 0.0 * param.sum() + + train_loss_batch.backward() + torch.nn.utils.clip_grad_norm_( + model.parameters(), self.hypers["grad_clip_norm"] + ) + optimizer.step() + lr_scheduler.step() + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(train_loss_batch) + train_loss += train_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + train_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + train_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_train_info = train_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_train_info.update( + train_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + val_loss = 0.0 + for batch in val_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=False, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + val_loss_batch = loss_fn(predictions, targets, extra_data) + + if 
is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(val_loss_batch) + val_loss += val_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + val_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + val_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_val_info = val_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_val_info.update( + val_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + # Now we log the information: + finalized_train_info = { + "loss": train_loss, + **finalized_train_info, + } + finalized_val_info = { + "loss": val_loss, + **finalized_val_info, + } + + if epoch == start_epoch: + metric_logger = MetricLogger( + log_obj=ROOT_LOGGER, + dataset_info=( + model.module if is_distributed else model + ).dataset_info, + initial_metrics=[finalized_train_info, finalized_val_info], + names=["training", "validation"], + ) + if epoch % self.hypers["log_interval"] == 0: + metric_logger.log( + metrics=[finalized_train_info, finalized_val_info], + epoch=epoch, + rank=rank, + learning_rate=optimizer.param_groups[0]["lr"], + ) + + val_metric = get_selected_metric( + finalized_val_info, self.hypers["best_model_metric"] + ) + if val_metric < self.best_metric: + self.best_metric = val_metric + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + if epoch % self.hypers["checkpoint_interval"] == 0: + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + self.epoch = epoch + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_{epoch}.ckpt", + ) + + # prepare for the checkpoint that will be saved outside the function + self.epoch = epoch + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + + if is_distributed: + torch.distributed.destroy_process_group() + + def save_checkpoint(self, model: ModelInterface, path: Union[str, Path]) -> None: + checkpoint = model.get_checkpoint() + if self.best_model_state_dict is not None: + self.best_model_state_dict["finetune_config"] = model.finetune_config + checkpoint.update( + { + "trainer_ckpt_version": self.__checkpoint_version__, + "train_hypers": self.hypers, + "epoch": self.epoch, + "optimizer_state_dict": self.optimizer_state_dict, + "scheduler_state_dict": self.scheduler_state_dict, + "best_epoch": self.best_epoch, + "best_metric": self.best_metric, + "best_model_state_dict": self.best_model_state_dict, + "best_optimizer_state_dict": self.best_optimizer_state_dict, + } + ) + torch.save( + checkpoint, + check_file_extension(path, ".ckpt"), + ) + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + hypers: TrainerHypers, + context: Literal["restart", "finetune"], + ) -> "Trainer": + trainer = cls(hypers) + trainer.optimizer_state_dict = 
checkpoint["optimizer_state_dict"] + trainer.scheduler_state_dict = checkpoint["scheduler_state_dict"] + trainer.epoch = checkpoint["epoch"] + trainer.best_epoch = checkpoint["best_epoch"] + trainer.best_metric = checkpoint["best_metric"] + trainer.best_model_state_dict = checkpoint["best_model_state_dict"] + trainer.best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"] + + return trainer + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["trainer_ckpt_version"] == v: + update = getattr(checkpoints, f"trainer_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["trainer_ckpt_version"] = v + 1 + + if checkpoint["trainer_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using " + f"trainer version {checkpoint['trainer_ckpt_version']}, while the " + f"current trainer version is {cls.__checkpoint_version__}." + ) + return checkpoint diff --git a/src/metatrain/pet_14/__init__.py b/src/metatrain/pet_14/__init__.py new file mode 100644 index 0000000000..01db7d8aa9 --- /dev/null +++ b/src/metatrain/pet_14/__init__.py @@ -0,0 +1,21 @@ +from .model import PET +from .trainer import Trainer + + +__model__ = PET +__trainer__ = Trainer +__capabilities__ = { + "supported_devices": __model__.__supported_devices__, + "supported_dtypes": __model__.__supported_dtypes__, +} + +__authors__ = [ + ("Sergey Pozdnyakov ", "@spozdn"), + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] + +__maintainers__ = [ + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] diff --git a/src/metatrain/pet_14/checkpoints.py b/src/metatrain/pet_14/checkpoints.py new file mode 100644 index 0000000000..020b1e6590 --- /dev/null +++ b/src/metatrain/pet_14/checkpoints.py @@ -0,0 +1,422 @@ +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + + +########################### +# MODEL ################### +########################### + + +def model_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "additive_models.0.model.type_to_index" not in state_dict: + state_dict["additive_models.0.model.type_to_index"] = state_dict.pop( + "additive_models.0.type_to_index" + ) + + +def model_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "train_hypers" in checkpoint: + finetune_config = checkpoint["train_hypers"].get("finetune", {}) + else: + finetune_config = {} + state_dict["finetune_config"] = finetune_config + + +def model_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["epoch"] = checkpoint.get("epoch") + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + if checkpoint["best_model_state_dict"] is not None: + checkpoint["best_model_state_dict"] = checkpoint.get("best_model_state_dict") + + +def model_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 checkpoint to v5. + + :param checkpoint: The checkpoint to update. 
+ """ + checkpoint["model_data"]["dataset_info"]._atomic_types = list( + checkpoint["model_data"]["dataset_info"]._atomic_types + ) + + +def model_update_v5_v6(checkpoint: dict) -> None: + """ + Update a v5 checkpoint to v6. + + :param checkpoint: The checkpoint to update. + """ + if not checkpoint["best_model_state_dict"]: + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + +def model_update_v6_v7(checkpoint: dict) -> None: + """ + Update model checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to be updated. + """ + # this update consists in changes in the scaler + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if ( + "scaler.scales" not in state_dict + and "scaler.dummy_buffer" in state_dict + and "scaler.model.type_to_index" in state_dict + ): + continue # already updated + old_scales_tensor = state_dict.pop("scaler.scales") + old_output_name_to_output_index = {} + for target_index, target_name in enumerate( + checkpoint["model_data"]["dataset_info"].targets.keys() + ): + old_output_name_to_output_index[target_name] = target_index + state_dict["scaler.dummy_buffer"] = torch.tensor( + [0.0], dtype=old_scales_tensor.dtype + ) + state_dict["scaler.model.type_to_index"] = state_dict[ + "additive_models.0.model.type_to_index" + ] + for target_name, target_info in checkpoint["model_data"][ + "dataset_info" + ].targets.items(): + layout = target_info.layout + if layout.sample_names == ["system"]: + samples = Labels(["atomic_type"], torch.tensor([[-1]])) + + elif layout.sample_names == ["system", "atom"]: + samples = Labels( + ["atomic_type"], + torch.arange( + len(checkpoint["model_data"]["dataset_info"].atomic_types) + ).reshape(-1, 1), + ) + else: + raise ValueError( # will never happen + "Unknown sample kind. Please contact the developers." + ) + scales_tensormap = TensorMap( + keys=layout.keys, + blocks=[ + TensorBlock( + values=torch.full( # important when scale_targets=False + (len(samples), len(block.properties)), + old_scales_tensor[ + old_output_name_to_output_index[target_name] + ], + dtype=torch.float64, + ), + samples=samples, + components=[], + properties=block.properties, + ) + for block in layout.blocks() + ], + ) + state_dict[f"scaler.{target_name}_scaler_buffer"] = mts.save_buffer( + mts.make_contiguous(scales_tensormap) + ) + + +def model_update_v7_v8(checkpoint: dict) -> None: + """ + Update a v7 checkpoint to v8. + + :param checkpoint: The checkpoint to update. + """ + if "node_embedding.weight" in checkpoint["model_state_dict"]: + ############################################## + # **Updating the large-scale PET checkpoints** + ############################################## + + checkpoint["model_data"]["model_hypers"]["normalization"] = "RMSNorm" + checkpoint["model_data"]["model_hypers"]["activation"] = "SwiGLU" + checkpoint["model_data"]["model_hypers"]["d_node"] = ( + 4 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PreLN" + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "feedforward" + checkpoint["model_data"]["model_hypers"]["d_feedforward"] = ( + 2 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + if (state_dict := checkpoint.get("model_state_dict")) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + if "embedding." 
in k and "node" not in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if "node_embedding." in k: + k = k.replace("node_embedding.", "node_embedders.0.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + if "combination_rmsnorms" in k: + k = k.replace("combination_rmsnorms", "combination_norms") + new_state_dict[k] = v + checkpoint["model_state_dict"] = new_state_dict + + # for the large-scale checkpoints, the model for evaluation is always + # taken to be the last + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + else: + ############################################### + # **Updating the standard old PET checkpoints** + ############################################### + + # Adding the option for choosing the normalization type + if "normalization" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["normalization"] = "LayerNorm" + # Adding the option for choosing the activation function + if "activation" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["activation"] = "SiLU" + # Setting the node features dimension to be the same as d_pet if not specified + if "d_node" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["d_node"] = checkpoint[ + "model_data" + ]["model_hypers"]["d_pet"] + # Setting the default transformer type to PostLN if not specified + if "transformer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PostLN" + if "featurizer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "residual" + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + # Replacing the nn.Sequential MLP with a custom FeedForward module + if ".mlp.0" in k: + k = k.replace(".mlp.0", ".mlp.w_in") + if ".mlp.3" in k: + k = k.replace(".mlp.3", ".mlp.w_out") + # Moving the node embedder to a top-level PET model attribute + if "embedding." in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".node_embedder." in k: + key_content = k.split(".") + k = ".".join( + ["node_embedders", key_content[1], key_content[-1]] + ) + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + new_state_dict[k] = v + checkpoint[key] = new_state_dict + + +def model_update_v8_v9(checkpoint: dict) -> None: + """ + Update a v8 checkpoint to v9. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "finetune_config" in state_dict: + if "inherit_heads" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["inherit_heads"] = {} + if "method" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["method"] = "full" + + +def model_update_v9_v10(checkpoint: dict) -> None: + """ + Update a v9 checkpoint to v10. + + :param checkpoint: The checkpoint to update. 
+ """ + # Adding the num_neighbors_adaptive hyperparameter if not present + if "num_neighbors_adaptive" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["num_neighbors_adaptive"] = None + if "cutoff_function" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["cutoff_function"] = "Cosine" + + +def model_update_v10_v11(checkpoint: dict) -> None: + """ + Update a v10 checkpoint to v11. + + :param checkpoint: The checkpoint to update. + """ + # Adding the attention_temperature hyperparameter if not present + if "attention_temperature" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["attention_temperature"] = 1.0 + + +########################### +# TRAINER ################# +########################### + + +def trainer_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 Trainer checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["scheduler_factor"] = checkpoint["train_hypers"].get( + "scheduler_factor", 0.5 + ) + + +def trainer_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 Trainer checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + +def trainer_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 Trainer checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + new_loss_hypers[target_name] = { + "type": old_loss_hypers["type"], + "weight": old_loss_hypers["weights"].get(target_name, 1.0), + "reduction": old_loss_hypers["reduction"], + "sliding_factor": old_loss_hypers.get("sliding_factor", None), + } + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 Trainer checkpoint to v5. + + :param checkpoint: The checkpoint to update. + """ + raise ValueError( + "In order to use this checkpoint, you need metatrain 2025.10 or earlier. " + "You can install it with `pip install metatrain==2025.10`." + ) + + +def trainer_update_v5_v6(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 5 to version 6. + + :param checkpoint: The checkpoint to update. + """ + # num_workers=0 means that the main process will do the data loading, which is + # equivalent to not setting it (this was the behavior before v6) + checkpoint["train_hypers"]["num_workers"] = 0 + + +def trainer_update_v6_v7(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["fixed_scaling_weights"] = {} + + +def trainer_update_v7_v8(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. 
+ """ + # remove all entries in the loss `sliding_factor` + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + # retain everything except sliding_factor for each target + new_loss_hypers[target_name] = { + k: v + for k, v in old_loss_hypers[target_name].items() + if k != "sliding_factor" + } + + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v8_v9(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. + """ + # Adding the empty finetune config if not present + if "finetune" not in checkpoint["train_hypers"]: + checkpoint["train_hypers"]["finetune"] = { + "read_from": None, + "method": "full", + "config": {}, + "inherit_heads": {}, + } + + +def trainer_update_v9_v10(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 9 to version 10. + + :param checkpoint: The checkpoint to update. + """ + # Ensuring that the finetune read_from is None if not specified + checkpoint["train_hypers"]["remove_composition_contribution"] = True + + +def trainer_update_v10_v11(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 10 to version 11. + + :param checkpoint: The checkpoint to update. + """ + # - Remove the ``remove_composition_contribution`` hyper. + # - Rename ``fixed_composition_weights`` to ``atomic_baseline``. + # - If ``remove_composition_contribution`` is False, set all atomic baselines + # to 0.0 for all targets. + use_atomic_baseline = checkpoint["train_hypers"].pop( + "remove_composition_contribution" + ) + atomic_baseline = checkpoint["train_hypers"].pop("fixed_composition_weights") + + if not use_atomic_baseline: + # Just set + dataset_info = checkpoint["model_data"]["dataset_info"] + atomic_baseline = {target_name: 0.0 for target_name in dataset_info.targets} + + checkpoint["train_hypers"]["atomic_baseline"] = atomic_baseline + + +def trainer_update_v11_v12(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 11 to version 12. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["batch_atom_bounds"] = [None, None] diff --git a/src/metatrain/pet_14/documentation.py b/src/metatrain/pet_14/documentation.py new file mode 100644 index 0000000000..5dc246c608 --- /dev/null +++ b/src/metatrain/pet_14/documentation.py @@ -0,0 +1,251 @@ +""" +PET +=== + +PET is a cleaner, more user-friendly reimplementation of the original +PET model :footcite:p:`pozdnyakov_smooth_2023`. It is designed for better +modularity and maintainability, while preseving compatibility with the original +PET implementation in ``metatrain``. It also adds new features like long-range +features, better fine-tuning implementation, a possibility to train on +arbitrarty targets, and a faster inference due to the ``fast attention``. + +{{SECTION_INSTALLATION}} + +{{SECTION_DEFAULT_HYPERS}} + +Tuning hyperparameters +---------------------- + +The default hyperparameters above will work well in most cases, but they +may not be optimal for your specific dataset. There is good number of +parameters to tune, both for the +:ref:`model ` and the +:ref:`trainer `. Since seeing them +for the first time might be overwhelming, here we provide a **list of the +parameters that are in general the most important** (in decreasing order +of importance): + +.. container:: mtt-hypers-remove-classname + + .. 
+    .. autoattribute:: {{model_hypers_path}}.cutoff
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.num_neighbors_adaptive
+        :no-index:
+
+    .. autoattribute:: {{trainer_hypers_path}}.learning_rate
+        :no-index:
+
+    .. autoattribute:: {{trainer_hypers_path}}.batch_size
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.d_pet
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.d_node
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.num_gnn_layers
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.num_attention_layers
+        :no-index:
+
+    .. autoattribute:: {{trainer_hypers_path}}.loss
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.long_range
+        :no-index:
+"""
+
+from typing import Literal, Optional
+
+from typing_extensions import TypedDict
+
+from metatrain.pet_14.modules.finetuning import FinetuneHypers, NoFinetuneHypers
+from metatrain.utils.additive import FixedCompositionWeights
+from metatrain.utils.hypers import init_with_defaults
+from metatrain.utils.long_range import LongRangeHypers
+from metatrain.utils.loss import LossSpecification
+from metatrain.utils.scaler import FixedScalerWeights
+
+
+class ModelHypers(TypedDict):
+    """Hyperparameters for the PET model."""
+
+    cutoff: float = 4.5
+    """Cutoff radius for neighbor search.
+
+    This should be set to a value after which most of the interactions
+    between atoms are expected to be negligible. A lower cutoff will lead
+    to faster models.
+    """
+    num_neighbors_adaptive: Optional[int] = None
+    """Target number of neighbors for the adaptive cutoff scheme.
+
+    This parameter activates the adaptive cutoff functionality.
+    Each atomic environment has a different cutoff, which is chosen
+    such that the number of neighbors is approximately equal to this
+    value. This can be useful to have a more uniform number of neighbors
+    per atom, especially in sparse systems. Setting it to None disables
+    this feature and uses all neighbors within the fixed cutoff radius.
+    """
+    cutoff_function: Literal["Cosine", "Bump"] = "Bump"
+    """Type of the smoothing function at the cutoff."""
+    cutoff_width: float = 0.5
+    """Width of the smoothing function at the cutoff."""
+    d_pet: int = 128
+    """Dimension of the edge features.
+
+    This hyperparameter controls the width of the neural network. In general,
+    increasing it might lead to better accuracy, especially on larger datasets, at the
+    cost of increased training and evaluation time.
+    """
+    d_head: int = 128
+    """Dimension of the attention heads."""
+    d_node: int = 256
+    """Dimension of the node features.
+
+    Increasing this hyperparameter might lead to better accuracy,
+    with a relatively small increase in inference time.
+    """
+    d_feedforward: int = 256
+    """Dimension of the feedforward network in the attention layer."""
+    num_heads: int = 8
+    """Number of attention heads per attention layer."""
+    num_attention_layers: int = 2
+    """The number of attention layers in each layer of the graph
+    neural network. Depending on the dataset, increasing this hyperparameter might
+    lead to better accuracy, at the cost of increased training and evaluation time.
+    """
+    num_gnn_layers: int = 2
+    """The number of graph neural network layers.
+
+    In general, decreasing this hyperparameter to 1 will lead to much faster models,
+    at the expense of accuracy. Increasing it may or may not lead to better accuracy,
+    depending on the dataset, at the cost of increased training and evaluation time.
+ """ + normalization: Literal["RMSNorm", "LayerNorm"] = "RMSNorm" + """Layer normalization type.""" + activation: Literal["SiLU", "SwiGLU"] = "SwiGLU" + """Activation function.""" + attention_temperature: float = 1.0 + """The temperature scaling factor for attention scores.""" + transformer_type: Literal["PreLN", "PostLN"] = "PreLN" + """The order in which the layer normalization and attention + are applied in a transformer block. Available options are ``PreLN`` + (normalization before attention) and ``PostLN`` (normalization after attention).""" + featurizer_type: Literal["residual", "feedforward"] = "feedforward" + """Implementation of the featurizer of the model to use. Available + options are ``residual`` (the original featurizer from the PET paper, that uses + residual connections at each GNN layer for readout) and ``feedforward`` (a modern + version that uses the last representation after all GNN iterations for readout). + Additionally, the feedforward version uses bidirectional features flow during the + message passing iterations, that favors features flowing from atom ``i`` to atom + ``j`` to be not equal to the features flowing from atom ``j`` to atom ``i``.""" + zbl: bool = False + """Use ZBL potential for short-range repulsion""" + long_range: LongRangeHypers = init_with_defaults(LongRangeHypers) + """Long-range Coulomb interactions parameters.""" + + +class TrainerHypers(TypedDict): + """Hyperparameters for training PET models.""" + + distributed: bool = False + """Whether to use distributed training""" + distributed_port: int = 39591 + """Port for distributed communication among processes""" + batch_size: int = 16 + """The number of samples to use in each batch of training. This + hyperparameter controls the tradeoff between training speed and memory usage. In + general, larger batch sizes will lead to faster training, but might require more + memory.""" + num_epochs: int = 1000 + """Number of epochs.""" + warmup_fraction: float = 0.01 + """Fraction of training steps used for learning rate warmup.""" + learning_rate: float = 1e-4 + """Learning rate.""" + weight_decay: Optional[float] = None + + log_interval: int = 1 + """Interval to log metrics.""" + checkpoint_interval: int = 100 + """Interval to save checkpoints.""" + atomic_baseline: FixedCompositionWeights = {} + """The baselines for each target. + + By default, ``metatrain`` will fit a linear model (:class:`CompositionModel + `) to compute the + least squares baseline for each atomic species for each target. + + However, this hyperparameter allows you to provide your own baselines. + The value of the hyperparameter should be a dictionary where the keys are the + target names, and the values are either (1) a single baseline to be used for + all atomic types, or (2) a dictionary mapping atomic types to their baselines. + For example: + + - ``atomic_baseline: {"energy": {1: -0.5, 6: -10.0}}`` will fix the energy + baseline for hydrogen (Z=1) to -0.5 and for carbon (Z=6) to -10.0, while + fitting the baselines for the energy of all other atomic types, as well + as fitting the baselines for all other targets. + - ``atomic_baseline: {"energy": -5.0}`` will fix the energy baseline for + all atomic types to -5.0. + - ``atomic_baseline: {"mtt:dos": 0.0}`` sets the baseline for the "mtt:dos" + target to 0.0, effectively disabling the atomic baseline for that target. 
+ + This atomic baseline is substracted from the targets during training, which + avoids the main model needing to learn atomic contributions, and likely makes + training easier. When the model is used in evaluation mode, the atomic baseline + is added on top of the model predictions automatically. + + .. note:: + + This atomic baseline is a per-atom contribution. Therefore, if the property + you are predicting is a sum over all atoms (e.g., total energy), the + contribution of the atomic baseline to the total property will be the + atomic baseline multiplied by the number of atoms of that type in the + structure. + """ + scale_targets: bool = True + """Normalize targets to unit std during training.""" + fixed_scaling_weights: FixedScalerWeights = {} + """Weights for target scaling. + + This is passed to the ``fixed_weights`` argument of + :meth:`Scaler.train_model `, + see its documentation to understand exactly what to pass here. + """ + per_structure_targets: list[str] = [] + """Targets to calculate per-structure losses.""" + num_workers: Optional[int] = None + """Number of workers for data loading. If not provided, it is set + automatically.""" + log_mae: bool = True + """Log MAE alongside RMSE""" + log_separate_blocks: bool = False + """Log per-block error.""" + best_model_metric: Literal["rmse_prod", "mae_prod", "loss"] = "mae_prod" + """Metric used to select best checkpoint (e.g., ``rmse_prod``)""" + grad_clip_norm: float = 1.0 + """Maximum gradient norm value.""" + loss: str | dict[str, LossSpecification | str] = "mse" + """This section describes the loss function to be used. See the + :ref:`loss-functions` for more details.""" + batch_atom_bounds: list[Optional[int]] = [None, None] + """Bounds for the number of atoms per batch as [min, max]. Batches with atom + counts outside these bounds will be skipped during training. Use ``None`` for + either value to disable that bound. This is useful for preventing out-of-memory + errors and ensuring consistent computational load. Default: ``[None, None]``.""" + + finetune: NoFinetuneHypers | FinetuneHypers = { + "read_from": None, + "method": "full", + "config": {}, + "inherit_heads": {}, + } + """Parameters for fine-tuning trained PET models. + + See :ref:`label_fine_tuning_concept` for more details. + """ diff --git a/src/metatrain/pet_14/model.py b/src/metatrain/pet_14/model.py new file mode 100644 index 0000000000..57dc519929 --- /dev/null +++ b/src/metatrain/pet_14/model.py @@ -0,0 +1,1646 @@ +import logging +import typing +import warnings +from math import prod +from typing import Any, Dict, List, Literal, Optional, Tuple + +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.operations._add import _add_block_block +from metatomic.torch import ( + AtomisticModel, + ModelCapabilities, + ModelMetadata, + ModelOutput, + NeighborListOptions, + System, +) +from torch.utils.hooks import RemovableHandle + +from metatrain.utils.abc import ModelInterface +from metatrain.utils.additive import ZBL, CompositionModel +from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.dtype import dtype_to_str +from metatrain.utils.long_range import DummyLongRangeFeaturizer, LongRangeFeaturizer +from metatrain.utils.metadata import merge_metadata +from metatrain.utils.scaler import Scaler +from metatrain.utils.sum_over_atoms import sum_over_atoms + +from . 
import checkpoints +from .documentation import ModelHypers +from .modules.finetuning import apply_finetuning_strategy +from .modules.structures import get_pair_sample_labels, systems_to_batch +from .modules.transformer import CartesianTransformer + + +AVAILABLE_FEATURIZERS = typing.get_args(ModelHypers.__annotations__["featurizer_type"]) + + +class PET(ModelInterface[ModelHypers]): + """ + Metatrain-native implementation of the PET architecture. + + Originally proposed in work (https://arxiv.org/abs/2305.19302v3), + and published in the `pet` package (https://github.com/spozdn/pet). + + :param hypers: Hyperparameters for the PET model. See the documentation for details. + :param dataset_info: Information about the dataset, including atomic types and + targets. + """ + + __checkpoint_version__ = 11 + __supported_devices__ = ["cuda", "cpu"] + __supported_dtypes__ = [torch.float32, torch.float64] + __default_metadata__ = ModelMetadata( + references={"architecture": ["https://arxiv.org/abs/2305.19302v3"]} + ) + component_labels: Dict[str, List[List[Labels]]] + NUM_FEATURE_TYPES: int = 2 # node + edge features + + def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None: + super().__init__(hypers, dataset_info, self.__default_metadata__) + + # Cache frequently accessed hyperparameters + self.cutoff = float(self.hypers["cutoff"]) + self.cutoff_function = self.hypers["cutoff_function"] + self.cutoff_width = float(self.hypers["cutoff_width"]) + self.num_neighbors_adaptive = ( + float(self.hypers["num_neighbors_adaptive"]) + if self.hypers["num_neighbors_adaptive"] is not None + else None + ) + self.d_pet = self.hypers["d_pet"] + self.d_node = self.hypers["d_node"] + self.d_head = self.hypers["d_head"] + self.d_feedforward = self.hypers["d_feedforward"] + self.num_heads = self.hypers["num_heads"] + self.num_gnn_layers = self.hypers["num_gnn_layers"] + self.num_attention_layers = self.hypers["num_attention_layers"] + self.normalization = self.hypers["normalization"] + self.activation = self.hypers["activation"] + self.attention_temperature = self.hypers["attention_temperature"] + self.transformer_type = self.hypers["transformer_type"] + self.featurizer_type = self.hypers["featurizer_type"] + + self.atomic_types = dataset_info.atomic_types + self.requested_nl = NeighborListOptions( + cutoff=self.cutoff, + full_list=True, + strict=True, + ) + num_atomic_species = len(self.atomic_types) + self.gnn_layers = torch.nn.ModuleList( + [ + CartesianTransformer( + self.cutoff, + self.cutoff_width, + self.d_pet, + self.num_heads, + self.d_node, + self.d_feedforward, + self.num_attention_layers, + self.normalization, + self.activation, + self.attention_temperature, + self.transformer_type, + num_atomic_species, + layer_index == 0, # is first layer + ) + for layer_index in range(self.num_gnn_layers) + ] + ) + if self.featurizer_type not in AVAILABLE_FEATURIZERS: + raise ValueError( + f"Unknown featurizer type: {self.featurizer_type}. 
" + f"Available options are: {AVAILABLE_FEATURIZERS}" + ) + if self.featurizer_type == "feedforward": + self.num_readout_layers = 1 + self.combination_norms = torch.nn.ModuleList( + [torch.nn.LayerNorm(2 * self.d_pet, bias=False) for _ in range(self.num_gnn_layers)] + ) + self.combination_mlps = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(2 * self.d_pet, 2 * self.d_pet, bias=False), + torch.nn.SiLU(), + torch.nn.Linear(2 * self.d_pet, self.d_pet, bias=False), + ) + for _ in range(self.num_gnn_layers) + ] + ) + else: + self.num_readout_layers = self.num_gnn_layers + self.combination_norms = torch.nn.ModuleList() + self.combination_mlps = torch.nn.ModuleList() + + self.node_embedders = torch.nn.ModuleList( + [ + torch.nn.Embedding(num_atomic_species, self.d_node) + for _ in range(self.num_readout_layers) + ] + ) + self.edge_embedder = torch.nn.Embedding(num_atomic_species, self.d_pet) + + self.node_heads = torch.nn.ModuleDict() + self.edge_heads = torch.nn.ModuleDict() + self.node_last_layers = torch.nn.ModuleDict() + self.edge_last_layers = torch.nn.ModuleDict() + self.last_layer_feature_size = ( + self.num_readout_layers * self.d_head * self.NUM_FEATURE_TYPES + ) # for LLPR + + # the model is always capable of outputting the internal features + self.outputs = { + "features": ModelOutput(per_atom=True, description="internal features") + } + + self.output_shapes: Dict[str, Dict[str, List[int]]] = {} + self.key_labels: Dict[str, Labels] = {} + self.property_labels: Dict[str, List[Labels]] = {} + self.component_labels: Dict[str, List[List[Labels]]] = {} + self.target_names: List[str] = [] + self.last_layer_parameter_names: Dict[str, List[str]] = {} # for LLPR + for target_name, target_info in dataset_info.targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target_info) + + self.register_buffer( + "species_to_species_index", + torch.full((max(self.atomic_types) + 1,), -1), + ) + for i, species in enumerate(self.atomic_types): + self.species_to_species_index[species] = i + + # long-range module + if self.hypers["long_range"]["enable"]: + self.long_range = True + if not self.hypers["long_range"]["use_ewald"]: + warnings.warn( + "Training PET with the LongRangeFeaturizer initialized " + "with `use_ewald=False` causes instabilities during training. " + "The `use_ewald` variable will be force-switched to `True`. 
" + "during training.", + UserWarning, + stacklevel=2, + ) + self.long_range_featurizer = LongRangeFeaturizer( + hypers=self.hypers["long_range"], + feature_dim=self.d_node, + neighbor_list_options=self.requested_nl, + ) + else: + self.long_range = False + self.long_range_featurizer = DummyLongRangeFeaturizer() # for torchscript + + # additive models: these are handled by the trainer at training + # time, and they are added to the output at evaluation time + composition_model = CompositionModel( + hypers={}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + additive_models = [composition_model] + + # Adds the ZBL repulsion model if requested + if self.hypers["zbl"]: + additive_models.append( + ZBL( + {}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if ZBL.is_valid_target(target_name, target_info) + }, + ), + ) + ) + self.additive_models = torch.nn.ModuleList(additive_models) + + # scaler: this is also handled by the trainer at training time + self.scaler = Scaler(hypers={}, dataset_info=dataset_info) + + self.single_label = Labels.single() + + self.finetune_config: Dict[str, Any] = {} + + def supported_outputs(self) -> Dict[str, ModelOutput]: + return self.outputs + + def restart(self, dataset_info: DatasetInfo) -> "PET": + # merge old and new dataset info + merged_info = self.dataset_info.union(dataset_info) + new_atomic_types = [ + at for at in merged_info.atomic_types if at not in self.atomic_types + ] + new_targets = { + key: value + for key, value in merged_info.targets.items() + if key not in self.dataset_info.targets + } + self.has_new_targets = len(new_targets) > 0 + + if len(new_atomic_types) > 0: + raise ValueError( + f"New atomic types found in the dataset: {new_atomic_types}. " + "The PET model does not support adding new atomic types." + ) + + # register new outputs as new last layers + for target_name, target in new_targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target) + + self.dataset_info = merged_info + + # restart the composition and scaler models + self.additive_models[0] = self.additive_models[0].restart( + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + self.scaler = self.scaler.restart(dataset_info) + + return self + + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return [self.requested_nl] + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + """ + Forward pass of the PET model. + + The forward pass processes atomic systems through multiple stages to produce + predictions for the requested outputs. The computation follows a graph neural + network architecture with attention-based message passing. 
+ + **Stage 0: Input Preparation** + + The input systems are first converted into a batched representation containing: + + - `element_indices_nodes` [n_atoms]: Atomic species of the central atoms + - `element_indices_neighbors` [n_atoms, max_num_neighbors]: Atomic species of + neighboring atoms + - `edge_vectors` [n_atoms, max_num_neighbors, 3]: Cartesian edge vectors + between central atoms and their neighbors + - `padding_mask` [n_atoms, max_num_neighbors]: Mask indicating real vs padded + neighbors + - `reverse_neighbor_index` [n_atoms * max_num_neighbors]: Index of the ji edge + for each ij edge, once the edges are flattened into an array whose first + dimension is n_atoms * max_num_neighbors + - `system_indices` [n_atoms]: System index for each central atom + - `sample_labels` [n_atoms, 2]: Metatensor Labels containing indices of each + atom in each system + + **Stage 1: Feature Computation via GNN Layers** + + Node and edge representations are computed by iterating through the GNN layers + following one of two featurization strategies: + + - **Feedforward featurization**: Features are propagated through all + `num_gnn_layers` GNN layers sequentially, using only the final layer outputs + for readout. At each layer, forward and reversed edge messages are combined + using combination MLPs to enable bidirectional information flow. + + - **Residual featurization**: Intermediate node and edge features from each + GNN layer are saved and used during readout. Edge messages between layers are + averaged to maintain information from all hops. + + During this stage, the model: + + - Embeds atomic species into learned node and edge representations + - Applies Cartesian transformer layers to update features via attention + - Uses reversed neighbor lists to enable bidirectional message passing, where + the new input message from atom `j` to atom `i` in GNN layer N+1 is the + reversed message from atom `i` to atom `j` in GNN layer N + - Applies cutoff functions to weight interactions by distance + + If the long-range module is enabled, electrostatic features computed via Ewald + summation (during training) or Particle-Particle Particle Mesh Ewald (P3M) + (during evaluation) are added to the node features from each GNN layer. + + **Stage 2: Intermediate Feature Output (Optional)** + + If "features" is requested in the outputs, node and edge features from all + layers are concatenated to produce intermediate representations. Edge features + are summed over neighbors with cutoff weighting to obtain per-node + contributions. This output can be used for transfer learning or analysis. + + **Stage 3: Last Layer Feature Computation** + + For each requested output, output-specific heads (shallow MLPs with two linear + layers and SiLU activations) are applied to both node and edge features from + each GNN layer. This produces last layer features that are specialized for each + prediction target. These features can be optionally returned as auxiliary + outputs (e.g., "mtt::aux::energy_last_layer_features") for analysis or + transfer learning. 
+ + **Stage 4: Atomic Predictions** + + Final linear layers are applied to the last layer features to produce per-atom + predictions for each requested output: + + - Node and edge last layer features are processed through separate linear + layers for each output block + - Contributions from all GNN layers are summed + - Edge contributions are summed over neighbors with cutoff weighting + - For rank-2 Cartesian tensors (e.g., stress), predictions are symmetrized and + normalized by cell volume + - Multiple tensor blocks per output are handled independently + + **Post-processing (Evaluation Only)** + + During evaluation (not training), the following transformations are applied: + + 1. **Scaling**: Predictions are scaled using learned or configured scale + factors + 2. **Additive contributions**: Composition model and optional ZBL repulsion + contributions are added to the predictions + + :param systems: List of `metatomic.torch.System` objects to process. Each + system should contain atomic positions, species, and cell information, with + neighbor lists computed according to `requested_neighbor_lists()`. + :param outputs: Dictionary of requested outputs in the format + {output_name: ModelOutput(...)}. The model supports: + + - Target properties (energy, forces, stress, etc.) + - "features": intermediate representations from Stage 2 + - Auxiliary last layer features (e.g., + "mtt::aux::energy_last_layer_features") + + :param selected_atoms: Optional `metatensor.torch.Labels` object specifying a + subset of atoms for which to compute outputs. If `None`, all atoms are + included. This is useful for computing properties for specific atomic + environments. + :return: Dictionary of `metatensor.torch.TensorMap` objects containing the + requested outputs. Each TensorMap contains per-atom or per-structure + predictions (depending on the ModelOutput configuration) with appropriate + metatensor metadata (samples, components, properties). 
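+
+        **Example (sketch)**
+
+        A minimal, hypothetical call pattern, assuming ``model`` is a PET instance
+        trained on an ``energy`` target and ``systems`` already carry the neighbor
+        lists from :meth:`requested_neighbor_lists`::
+
+            from metatomic.torch import ModelOutput
+
+            outputs = {
+                "energy": ModelOutput(per_atom=False),
+                "features": ModelOutput(per_atom=True),
+            }
+            predictions = model(systems, outputs)
+            energy_per_structure = predictions["energy"].block().values
+            per_atom_features = predictions["features"].block().values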
+ """ + device = systems[0].device + return_dict: Dict[str, TensorMap] = {} + nl_options = self.requested_neighbor_lists()[0] + + if self.single_label.values.device != device: + self._move_labels_to_device(device) + + # **Stage 0: Input Preparation** + ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + padding_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + sample_labels, + centers, + nef_to_edges_neighbor, + ) = systems_to_batch( + systems, + nl_options, + self.atomic_types, + self.species_to_species_index, + self.cutoff_function, + self.cutoff_width, + self.num_neighbors_adaptive, + selected_atoms, + ) + + pair_sample_labels = get_pair_sample_labels( + systems, sample_labels, nl_options, device + ) + + # Optional diagnostic token capture: register temporary module hooks + diagnostic_handles = torch.jit.annotate(List[Any], []) + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + diagnostic_handles = self._prepare_diagnostic_handles( + outputs, + return_dict, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + + # the scaled_dot_product_attention function from torch cannot do + # double backward, so we will use manual attention if needed + use_manual_attention = edge_vectors.requires_grad and self.training + + with torch.profiler.record_function("PET::_calculate_features"): + # **Stage 1: Feature Computation via GNN Layers** + featurizer_inputs: Dict[str, torch.Tensor] = dict( + element_indices_nodes=element_indices_nodes, + element_indices_neighbors=element_indices_neighbors, + edge_vectors=edge_vectors, + edge_distances=edge_distances, + reverse_neighbor_index=reverse_neighbor_index, + padding_mask=padding_mask, + cutoff_factors=cutoff_factors, + ) + for featurizer_input_name, tensor in featurizer_inputs.items(): + if "mtt::features::" + featurizer_input_name in outputs: + return_dict["mtt::features::" + featurizer_input_name] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + node_features_list, edge_features_list = self._calculate_features( + featurizer_inputs, + use_manual_attention=use_manual_attention, + ) + + # If the long-range module is activated, we add the long-range features + # on top of the node features + + if self.long_range: + long_range_features = self._calculate_long_range_features( + systems, node_features_list, edge_distances, padding_mask + ) + for i in range(self.num_readout_layers): + node_features_list[i] = ( + node_features_list[i] + long_range_features + ) * 0.5**0.5 + + # **Stage 2: Intermediate Feature Output (Optional)** + with torch.profiler.record_function("PET::_get_output_features"): + if "features" in outputs: + features_dict = self._get_output_features( + node_features_list, + edge_features_list, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + # Since return_dict.update(features_dict) is not Torch-Scriptable, + # we use a simple iteration over the features_dict items. 
+ for k, v in features_dict.items(): + return_dict[k] = v + + # **Stage 3: Last Layer Feature Computation** + with torch.profiler.record_function("PET::_calculate_last_layer_features"): + node_last_layer_features_dict, edge_last_layer_features_dict = ( + self._calculate_last_layer_features( + node_features_list, + edge_features_list, + ) + ) + last_layer_features_dict = self._get_output_last_layer_features( + node_last_layer_features_dict, + edge_last_layer_features_dict, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + + for k, v in last_layer_features_dict.items(): + return_dict[k] = v + + # **Stage 4: Atomic Predictions** + with torch.profiler.record_function("PET::_calculate_atomic_predictions"): + node_atomic_predictions_dict, edge_atomic_predictions_dict = ( + self._calculate_atomic_predictions( + node_last_layer_features_dict, + edge_last_layer_features_dict, + padding_mask, + cutoff_factors, + outputs, + ) + ) + atomic_predictions_dict = self._get_output_atomic_predictions( + systems, + node_atomic_predictions_dict, + edge_atomic_predictions_dict, + edge_vectors, + system_indices, + sample_labels, + outputs, + selected_atoms, + ) + + for k, v in atomic_predictions_dict.items(): + return_dict[k] = v + + # **Post-processing (Evaluation Only)** + + with torch.profiler.record_function("PET::post-processing"): + if not self.training: + # at evaluation, we also introduce the scaler and additive contributions + return_dict = self.scaler( + systems, return_dict, selected_atoms=selected_atoms + ) + for additive_model in self.additive_models: + outputs_for_additive_model: Dict[str, ModelOutput] = {} + for name, output in outputs.items(): + if name in additive_model.outputs: + outputs_for_additive_model[name] = output + additive_contributions = additive_model( + systems, + outputs_for_additive_model, + selected_atoms, + ) + for name in additive_contributions: + # TODO: uncomment this after metatensor.torch.add + # is updated to handle sparse sums + # return_dict[name] = metatensor.torch.add( + # return_dict[name], + # additive_contributions[name].to( + # device=return_dict[name].device, + # dtype=return_dict[name].dtype + # ), + # ) + # TODO: "manual" sparse sum: update to metatensor.torch.add + # after sparse sum is implemented in metatensor.operations + output_blocks: List[TensorBlock] = [] + for k, b in return_dict[name].items(): + if k in additive_contributions[name].keys: + output_blocks.append( + _add_block_block( + b, + additive_contributions[name] + .block(k) + .to(device=b.device, dtype=b.dtype), + ) + ) + else: + output_blocks.append(b) + return_dict[name] = TensorMap( + return_dict[name].keys, output_blocks + ) + + # remove any diagnostic hooks we registered and attach tokens + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + for h in diagnostic_handles: + h.remove() + + return return_dict + + def _create_diagnostic_feature_tensormap( + self, + tensor: torch.Tensor, + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> TensorMap: + assert tensor.shape[0] == sample_labels.values.shape[0], ( + "diagnostic feature tensor must be per-atom or per-pair like in shape." 
+ f" Got tensor.shape = {tensor.shape}, " + ) + + device = tensor.device + + outp = tensor.detach().clone() + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + if outp.shape[1] == 1: # node-like, shape (n_atoms, 1, d) + outp = outp.squeeze(1) + labels = sample_labels + + else: # edge-like, shape (n_atoms, num_neighbors, d) + outp = outp[centers, nef_to_edges_neighbor] + labels = pair_sample_labels + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + return TensorMap( + Labels(["_"], torch.tensor([[0]])).to(device=device), + [ + TensorBlock( + values=outp, + samples=labels.to(device=device), + components=[], + properties=Labels( + ["_"], + torch.arange(outp.shape[1]).reshape(-1, 1), + ).to(device=device), + ) + ], + ) + + @torch.jit.ignore + def _prepare_diagnostic_handles( + self, + outputs: Dict[str, ModelOutput], + return_dict: Dict[str, Any], + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> List[Any]: + """ + Prepare forward hooks to capture diagnostic tokens from internal modules. + + :param outputs: Dictionary of requested outputs. + :param return_dict: Dictionary to store captured tokens. + :param centers: Tensor mapping center atoms to their indices. + :param nef_to_edges_neighbor: Tensor mapping neighbor edges to their indices. + :param sample_labels: Labels for individual atoms. + :param pair_sample_labels: Labels for atom pairs. + + :return: List of removable handles for the registered hooks. + """ + diagnostic_handles = [] + + def _resolve_module(path: str) -> Any: + obj: Any = self + for part in path.split("."): + if part.isdigit(): + obj = obj[int(part)] + else: + if not hasattr(obj, part): + raise AttributeError( + f"Module path '{path}' not found at '{part}'" + ) + obj = getattr(obj, part) + return obj + + # Build list of possible module paths that can be captured for these model + # hypers. 
+ possible_capture_paths: List[str] = [] + + for i in range(self.num_readout_layers): + possible_capture_paths.append(f"node_embedders.{i}") + possible_capture_paths.append("edge_embedder") + + for i in range(self.num_gnn_layers): + # Total GNN layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}_node") + possible_capture_paths.append(f"gnn_layers.{i}_edge") + + # Finer-grained: embeddings and compressions + possible_capture_paths.append(f"gnn_layers.{i}.edge_embedder") + if i > 0: + possible_capture_paths.append(f"gnn_layers.{i}.neighbor_embedder") + possible_capture_paths.append(f"gnn_layers.{i}.compress") + + # transformer layers + for j in range(self.num_attention_layers): + # Transformer layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_node") + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_edge") + + # Finer-grained: attention and MLP submodules + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_contraction" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_mlp" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_mlp" + ) + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}.mlp") + + # TODO: this is subtle, depending on whether the featurizer is feedforward or + # residual. To fix later. + # possible_capture_paths.append(f"combination_norms") + # possible_capture_paths.append(f"combination_mlps") + + for path in possible_capture_paths: + if "mtt::features::" + path not in outputs: + continue + + if "_node" in path: + suffix = "_node" + path = path.replace(suffix, "") + elif "_edge" in path: + suffix = "_edge" + path = path.replace(suffix, "") + else: + suffix = "" + + module = _resolve_module(path) + + def make_hook(p: str, suffix: str) -> Any: + def _hook( + module: torch.nn.Module, inp: Any, outp: Any + ) -> None: + if isinstance(outp, tuple): + assert "_node" in suffix or "_edge" in suffix, ( + "When capturing from a module that returns multiple " + "outputs, the requested output must carry the suffix " + "'_node' or '_edge'." + ) + if suffix == "_node": + tensor = outp[0] + else: + assert suffix == "_edge" + tensor = outp[1] + + return_dict[f"mtt::features::{p}{suffix}"] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + else: + return_dict["mtt::features::" + p] = ( + self._create_diagnostic_feature_tensormap( + outp, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + + return _hook + + handle = module.register_forward_hook(make_hook(path, suffix)) + diagnostic_handles.append(handle) + + return diagnostic_handles + + def _calculate_features( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Calculate node and edge features using the selected featurization strategy. + Returns lists of feature tensors from GNN layers. 
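+        The strategy is selected by ``self.featurizer_type``: ``"feedforward"``
+        dispatches to the feedforward implementation, any other value to the
+        residual one.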
+ + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors + - List of edge feature tensors + In the case of feedforward featurization, each list contains a single tensor + from the final GNN layer. In the case of residual featurization, each list + contains tensors from all GNN layers. + """ + if self.featurizer_type == "feedforward": + return self._feedforward_featurization_impl(inputs, use_manual_attention) + else: + return self._residual_featurization_impl(inputs, use_manual_attention) + + def _feedforward_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Feedforward featurization: iterates features through all GNN layers, + returning only the final layer outputs. Uses combination MLPs to mix + forward and reversed edge messages at each layer. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from the final GNN layer + - List of edge feature tensors from the final GNN layer + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + + input_node_embeddings = self.node_embedders[0](inputs["element_indices_nodes"]) + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for combination_norm, combination_mlp, gnn_layer in zip( + self.combination_norms, self.combination_mlps, self.gnn_layers, strict=True + ): + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + input_node_embeddings = output_node_embeddings + new_input_edge_embeddings = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + # input_messages = 0.5 * (output_edge_embeddings + new_input_messages) + concatenated = torch.cat( + [output_edge_embeddings, new_input_edge_embeddings], dim=-1 + ) + input_edge_embeddings = ( + input_edge_embeddings + + output_edge_embeddings + + combination_mlp(combination_norm(concatenated)) + ) + + node_features_list.append(input_node_embeddings) + edge_features_list.append(input_edge_embeddings) + return node_features_list, edge_features_list + + def _residual_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Residual featurization: saves intermediate features from each GNN layer + for use in readout. 
Averages forward and reversed edge messages between layers. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from all GNN layers + - List of edge feature tensors from all GNN layers + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for node_embedder, gnn_layer in zip( + self.node_embedders, self.gnn_layers, strict=True + ): + input_node_embeddings = node_embedder(inputs["element_indices_nodes"]) + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + node_features_list.append(output_node_embeddings) + edge_features_list.append(output_edge_embeddings) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + # (Flatten, index, and reshape to the original shape) + new_input_messages = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + input_edge_embeddings = 0.5 * (input_edge_embeddings + new_input_messages) + return node_features_list, edge_features_list + + def _calculate_long_range_features( + self, + systems: List[System], + node_features_list: List[torch.Tensor], + edge_distances: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + """ + Calculate long-range electrostatic features using Ewald summation. + Forces use_ewald=True during training for stability. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_distances: Tensor of edge distances [n_atoms, max_num_neighbors]. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :return: Tensor of long-range features [n_atoms, d_pet]. + """ + if self.training: + # Currently, the long-range implementation show instabilities + # during training if P3MCalculator is used instead of the + # EwaldCalculator. We will use the EwaldCalculator for training. 
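+            # Note that the flag is only forced here: outside of training,
+            # ``use_ewald`` is left untouched, so a P3M-based featurizer
+            # remains usable at inference time.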
+ self.long_range_featurizer.use_ewald = True + flattened_lengths = edge_distances[padding_mask] + short_range_features = ( + torch.stack(node_features_list).sum(dim=0) + * (1 / len(node_features_list)) ** 0.5 + ) + long_range_features = self.long_range_featurizer( + systems, short_range_features, flattened_lengths + ) + return long_range_features + + def _get_output_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Concatenate node and edge features from all layers into intermediate + feature representations. Edge features are summed with cutoff weighting. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping "features" to a TensorMap of intermediate + representations, either per-atom or summed over atoms. + """ + features_dict: Dict[str, TensorMap] = {} + node_features = torch.cat(node_features_list, dim=1) + edge_features = torch.cat(edge_features_list, dim=2) + edge_features = (edge_features * cutoff_factors[:, :, None]).sum(dim=1) + features = torch.cat([node_features, edge_features], dim=1) + + feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=features, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + features.shape[-1], device=features.device + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + feature_tmap = mts.slice( + feature_tmap, + axis="samples", + selection=selected_atoms, + ) + if requested_outputs["features"].per_atom: + features_dict["features"] = feature_tmap + else: + features_dict["features"] = sum_over_atoms(feature_tmap) + return features_dict + + def _calculate_last_layer_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + ) -> Tuple[Dict[str, List[torch.Tensor]], Dict[str, List[torch.Tensor]]]: + """ + Apply output-specific heads to node and edge features from each GNN layer. + Returns dictionaries mapping output names to lists of head-transformed features. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of node last layer features + - Dictionary mapping output names to lists of edge last layer features + """ + node_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + + # Calculating node last layer features + for output_name, node_heads in self.node_heads.items(): + if output_name not in node_last_layer_features_dict: + node_last_layer_features_dict[output_name] = [] + for i, node_head in enumerate(node_heads): + node_last_layer_features_dict[output_name].append( + node_head(node_features_list[i]) + ) + + # Calculating edge last layer features + for output_name, edge_heads in self.edge_heads.items(): + if output_name not in edge_last_layer_features_dict: + edge_last_layer_features_dict[output_name] = [] + for i, edge_head in enumerate(edge_heads): + edge_last_layer_features_dict[output_name].append( + edge_head(edge_features_list[i]) + ) + + return node_last_layer_features_dict, edge_last_layer_features_dict + + def _get_output_last_layer_features( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge last layer features for requested last layer + features output. Edge features are summed with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping requested last layer features output names + to TensorMaps of last layer features, either per-atom or summed over atoms. 
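+
+        For a target named ``energy``, for instance, the matching requested
+        output key would be ``mtt::aux::energy_last_layer_features`` (see
+        ``get_last_layer_features_name``).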
+ """ + last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + last_layer_features_outputs: Dict[str, TensorMap] = {} + for output_name in node_last_layer_features_dict.keys(): + if not should_compute_last_layer_features(output_name, requested_outputs): + continue + if output_name not in last_layer_features_dict: + last_layer_features_dict[output_name] = [] + for i in range(len(node_last_layer_features_dict[output_name])): + node_last_layer_features = node_last_layer_features_dict[output_name][i] + edge_last_layer_features = edge_last_layer_features_dict[output_name][i] + edge_last_layer_features = ( + edge_last_layer_features * cutoff_factors[:, :, None] + ).sum(dim=1) + last_layer_features_dict[output_name].append(node_last_layer_features) + last_layer_features_dict[output_name].append(edge_last_layer_features) + + for output_name in requested_outputs: + if not ( + output_name.startswith("mtt::aux::") + and output_name.endswith("_last_layer_features") + ): + continue + base_name = output_name.replace("mtt::aux::", "").replace( + "_last_layer_features", "" + ) + # the corresponding output could be base_name or mtt::base_name + if f"mtt::{base_name}" in last_layer_features_dict: + base_name = f"mtt::{base_name}" + last_layer_features_values = torch.cat( + last_layer_features_dict[base_name], dim=1 + ) + last_layer_feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=last_layer_features_values, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + last_layer_features_values.shape[-1], + device=last_layer_features_values.device, + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + last_layer_feature_tmap = mts.slice( + last_layer_feature_tmap, + axis="samples", + selection=selected_atoms, + ) + last_layer_features_options = requested_outputs[output_name] + if last_layer_features_options.per_atom: + last_layer_features_outputs[output_name] = last_layer_feature_tmap + else: + last_layer_features_outputs[output_name] = sum_over_atoms( + last_layer_feature_tmap + ) + return last_layer_features_outputs + + def _calculate_atomic_predictions( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + padding_mask: torch.Tensor, + cutoff_factors: torch.Tensor, + outputs: Dict[str, ModelOutput], + ) -> Tuple[ + Dict[str, List[List[torch.Tensor]]], Dict[str, List[List[torch.Tensor]]] + ]: + """ + Apply final linear layers to last layer features to produce + per-atom predictions. Handles multiple blocks per output and sums + edge contributions with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param outputs: Dictionary of requested outputs. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of lists of node atomic + prediction tensors (one list per GNN layer, one tensor per block) + - Dictionary mapping output names to lists of lists of edge atomic + prediction tensors (one list per GNN layer, one tensor per block) + """ + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + + # Computing node atomic predictions. Since we have last layer features + # for each GNN layer, and each last layer can have multiple blocks, + # we apply each last layer block to each of the last layer features. + + for output_name, node_last_layers in self.node_last_layers.items(): + if output_name in outputs: + node_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, node_last_layer in enumerate(node_last_layers): + node_last_layer_features = node_last_layer_features_dict[ + output_name + ][i] + node_atomic_predictions_by_block: List[torch.Tensor] = [] + for node_last_layer_by_block in node_last_layer.values(): + node_atomic_predictions_by_block.append( + node_last_layer_by_block(node_last_layer_features) + ) + node_atomic_predictions_dict[output_name].append( + node_atomic_predictions_by_block + ) + + # Computing edge atomic predictions. Following the same logic as above, + # we (1) iterate over the last layer features and last layer blocks, and (2) + # sum the edge features with cutoff factors to get their per-node contribution. + + for output_name, edge_last_layers in self.edge_last_layers.items(): + if output_name in outputs: + edge_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, edge_last_layer in enumerate(edge_last_layers): + edge_last_layer_features = edge_last_layer_features_dict[ + output_name + ][i] + edge_atomic_predictions_by_block: List[torch.Tensor] = [] + for edge_last_layer_by_block in edge_last_layer.values(): + edge_atomic_predictions = edge_last_layer_by_block( + edge_last_layer_features + ) + expanded_padding_mask = padding_mask[..., None].repeat( + 1, 1, edge_atomic_predictions.shape[2] + ) + edge_atomic_predictions = torch.where( + ~expanded_padding_mask, 0.0, edge_atomic_predictions + ) + edge_atomic_predictions_by_block.append( + (edge_atomic_predictions * cutoff_factors[:, :, None]).sum( + dim=1 + ) + ) + edge_atomic_predictions_dict[output_name].append( + edge_atomic_predictions_by_block + ) + + return node_atomic_predictions_dict, edge_atomic_predictions_dict + + def _get_output_atomic_predictions( + self, + systems: List[System], + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_vectors: torch.Tensor, + system_indices: torch.Tensor, + sample_labels: Labels, + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge atomic predictions into final TensorMaps. + Handles rank-2 Cartesian tensors by symmetrizing them. + Returns per-atom or per-structure predictions based on output configuration. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of node atomic prediction tensors (one list per GNN layer, + one tensor per block). 
+ :param edge_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of edge atomic prediction tensors (one list per GNN layer, + one tensor per block). + :param edge_vectors: Tensor of edge vectors [n_atoms, max_num_neighbors, 3]. + :param system_indices: Tensor mapping each atom to its system index + [n_atoms]. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param outputs: Dictionary of requested outputs. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :return: Dictionary mapping requested output names to TensorMaps of + predictions, either per-atom or summed over atoms. + """ + atomic_predictions_tmap_dict: Dict[str, TensorMap] = {} + for output_name in self.target_names: + if output_name in outputs: + atomic_predictions_by_block = { + key: torch.zeros( + 1, dtype=edge_vectors.dtype, device=edge_vectors.device + ) + for key in self.output_shapes[output_name].keys() + } + + node_atomic_predictions_by_block = node_atomic_predictions_dict[ + output_name + ] + edge_atomic_predictions_by_block = edge_atomic_predictions_dict[ + output_name + ] + for i in range(len(node_atomic_predictions_by_block)): + node_atomic_prediction_block = node_atomic_predictions_by_block[i] + edge_atomic_prediction_block = edge_atomic_predictions_by_block[i] + for j, key in enumerate(atomic_predictions_by_block): + node_atomic_predictions = node_atomic_prediction_block[j] + edge_atomic_predictions = edge_atomic_prediction_block[j] + atomic_predictions_by_block[key] = atomic_predictions_by_block[ + key + ] + (node_atomic_predictions + edge_atomic_predictions) + + if output_name == "non_conservative_stress": # TODO: variants + block_key = list(atomic_predictions_by_block.keys())[0] + output_shapes_values = list( + self.output_shapes[output_name].values() + ) + num_properties = output_shapes_values[0][-1] + symmetrized = process_non_conservative_stress( + atomic_predictions_by_block[block_key], + systems, + system_indices, + num_properties, + ) + atomic_predictions_by_block[block_key] = symmetrized + + blocks = [ + TensorBlock( + values=atomic_predictions_by_block[key].reshape([-1] + shape), + samples=sample_labels, + components=components, + properties=properties, + ) + for key, shape, components, properties in zip( + self.output_shapes[output_name].keys(), + self.output_shapes[output_name].values(), + self.component_labels[output_name], + self.property_labels[output_name], + strict=True, + ) + ] + atomic_predictions_tmap_dict[output_name] = TensorMap( + keys=self.key_labels[output_name], + blocks=blocks, + ) + # If selected atoms request is provided, we slice the atomic predictions + # tensor maps to get the predictions for the selected atoms only. + + if selected_atoms is not None: + for output_name, tmap in atomic_predictions_tmap_dict.items(): + atomic_predictions_tmap_dict[output_name] = mts.slice( + tmap, axis="samples", selection=selected_atoms + ) + + # If per-atom predictions are requested, we return the atomic predictions + # tensor maps. Otherwise, we sum the atomic predictions over the atoms + # to get the final per-structure predictions for each requested output. 
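+        # (For example, a per-structure energy is just the sum of the per-atom
+        # energy contributions of that structure, obtained via
+        # ``sum_over_atoms``.)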
+ + for output_name, atomic_property in atomic_predictions_tmap_dict.items(): + if outputs[output_name].per_atom: + atomic_predictions_tmap_dict[output_name] = atomic_property + else: + atomic_predictions_tmap_dict[output_name] = sum_over_atoms( + atomic_property + ) + + return atomic_predictions_tmap_dict + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + context: Literal["restart", "finetune", "export"], + ) -> "PET": + if context == "restart": + logging.info(f"Using latest model from epoch {checkpoint['epoch']}") + model_state_dict = checkpoint["model_state_dict"] + elif context in {"finetune", "export"}: + logging.info(f"Using best model from epoch {checkpoint['best_epoch']}") + model_state_dict = checkpoint["best_model_state_dict"] + else: + raise ValueError("Unknown context tag for checkpoint loading!") + + # Create the model + model_data = checkpoint["model_data"] + model = cls( + hypers=model_data["model_hypers"], + dataset_info=model_data["dataset_info"], + ) + + finetune_config = model_state_dict.pop("finetune_config", {}) + if finetune_config: + # Apply the finetuning strategy + model = apply_finetuning_strategy(model, finetune_config) + state_dict_iter = iter(model_state_dict.values()) + next(state_dict_iter) # skip the species_to_species_index + dtype = next(state_dict_iter).dtype + model.to(dtype).load_state_dict(model_state_dict) + model.additive_models[0].sync_tensor_maps() + model.scaler.sync_tensor_maps() + + # Loading the metadata from the checkpoint + model.metadata = merge_metadata(model.metadata, checkpoint.get("metadata")) + + return model + + def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: + dtype = next(self.parameters()).dtype + if dtype not in self.__supported_dtypes__: + raise ValueError(f"unsupported dtype {dtype} for PET") + + # Make sure the model is all in the same dtype + # For example, after training, the additive models could still be in + # float64 + self.to(dtype) + + # Additionally, the composition model contains some `TensorMap`s that cannot + # be registered correctly with Pytorch. This function moves them: + self.additive_models[0].weights_to(torch.device("cpu"), torch.float64) + + interaction_ranges = [self.num_gnn_layers * self.cutoff] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + + capabilities = ModelCapabilities( + outputs=self.outputs, + atomic_types=self.atomic_types, + interaction_range=interaction_range, + length_unit=self.dataset_info.length_unit, + supported_devices=self.__supported_devices__, + dtype=dtype_to_str(dtype), + ) + + metadata = merge_metadata(self.metadata, metadata) + + return AtomisticModel(self.eval(), metadata, capabilities) + + def _add_output(self, target_name: str, target_info: TargetInfo) -> None: + """ + Register a new output target by creating corresponding heads and last layers. + Sets up node/edge heads and linear layers for all readout layers. + + :param target_name: Name of the target to add. + :param target_info: TargetInfo object containing details about the target. + """ + # one output shape for each tensor block, grouped by target (i.e. 
tensormap) + self.output_shapes[target_name] = {} + for key, block in target_info.layout.items(): + dict_key = target_name + for n, k in zip(key.names, key.values, strict=True): + dict_key += f"_{n}_{int(k)}" + self.output_shapes[target_name][dict_key] = [ + len(comp.values) for comp in block.components + ] + [len(block.properties.values)] + + self.outputs[target_name] = ModelOutput( + quantity=target_info.quantity, + unit=target_info.unit, + per_atom=True, + description=target_info.description, + ) + + self.node_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_node, self.d_head, bias=False), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head, bias=False), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_pet, self.d_head, bias=False), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head, bias=False), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.node_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: torch.nn.Linear( + self.d_head, + prod(shape), + bias=("o3_lambda_)" in key and "o3_sigma_1" in key), + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: torch.nn.Linear( + self.d_head, + prod(shape), + bias=("o3_lambda_)" in key and "o3_sigma_1" in key), + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + # Register last-layer parameters, in the same order as they are returned as + # last-layer features in the model + self.last_layer_parameter_names[target_name] = [] + for layer_index in range(self.num_readout_layers): + for key in self.output_shapes[target_name].keys(): + self.last_layer_parameter_names[target_name].append( + f"node_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + self.last_layer_parameter_names[target_name].append( + f"edge_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + + ll_features_name = get_last_layer_features_name(target_name) + self.outputs[ll_features_name] = ModelOutput( + per_atom=True, description=f"last layer features for {target_name}" + ) + self.key_labels[target_name] = target_info.layout.keys + self.component_labels[target_name] = [ + block.components for block in target_info.layout.blocks() + ] + self.property_labels[target_name] = [ + block.properties for block in target_info.layout.blocks() + ] + + def _move_labels_to_device(self, device: torch.device) -> None: + self.single_label = self.single_label.to(device) + self.key_labels = { + output_name: label.to(device) + for output_name, label in self.key_labels.items() + } + self.component_labels = { + output_name: [ + [labels.to(device) for labels in components_block] + for components_block in components_tmap + ] + for output_name, components_tmap in self.component_labels.items() + } + self.property_labels = { + output_name: [labels.to(device) for labels in properties_tmap] + for output_name, properties_tmap in self.property_labels.items() + } + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["model_ckpt_version"] == v: + update = getattr(checkpoints, f"model_update_v{v}_v{v + 1}") + 
update(checkpoint) + checkpoint["model_ckpt_version"] = v + 1 + + if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using model " + f"version {checkpoint['model_ckpt_version']}, while the current model " + f"version is {cls.__checkpoint_version__}." + ) + + return checkpoint + + def get_checkpoint(self) -> Dict: + model_state_dict = self.state_dict() + model_state_dict["finetune_config"] = self.finetune_config + checkpoint = { + "architecture_name": "pet_14", + "model_ckpt_version": self.__checkpoint_version__, + "metadata": self.metadata, + "model_data": { + "model_hypers": self.hypers, + "dataset_info": self.dataset_info, + }, + "epoch": None, + "best_epoch": None, + "model_state_dict": model_state_dict, + "best_model_state_dict": self.state_dict(), + } + return checkpoint + + +def process_non_conservative_stress( + tensor: torch.Tensor, + systems: List[System], + system_indices: torch.Tensor, + num_properties: int, +) -> torch.Tensor: + """ + Symmetrizes and normalizes by the volume rank-2 Cartesian tensors that are meant + to predict the non-conservative stress. + + :param tensor: Tensor of shape [n_atoms, 9 * num_properties]. + :param systems: List of `metatomic.torch.System` objects to process. + :param system_indices: Tensor mapping each atom to its system index [n_atoms]. + :param num_properties: Number of properties in the tensor (e.g., 6 for stress). + :return: Symmetrized tensor of shape [n_atoms, 3, 3, num_properties], divided by the + cell volume. + """ + # Reshape to 3x3 matrix per atom + tensor_as_three_by_three = tensor.reshape(-1, 3, 3, num_properties) + + # Normalize by cell volume + volumes = torch.stack([torch.abs(torch.det(system.cell)) for system in systems]) + # Zero volume can happen due to metatomic's convention of zero cell + # vectors for non-periodic directions. The actual volume is +inf + volumes[volumes == 0.0] = torch.inf + volumes_by_atom = volumes[system_indices].unsqueeze(1).unsqueeze(2).unsqueeze(3) + tensor_as_three_by_three = tensor_as_three_by_three / volumes_by_atom + + # Symmetrize + tensor_as_three_by_three = ( + tensor_as_three_by_three + tensor_as_three_by_three.transpose(1, 2) + ) / 2.0 + + return tensor_as_three_by_three + + +def get_last_layer_features_name(target_name: str) -> str: + """ + Get the auxiliary output name for last layer features of a target. + + :param target_name: Name of the target. + :return: Name of the corresponding last layer features output. + """ + base_name = target_name.replace("mtt::", "") + return f"mtt::aux::{base_name}_last_layer_features" + + +def should_compute_last_layer_features( + output_name: str, requested_outputs: Dict[str, ModelOutput] +) -> bool: + """ + Check if last layer features should be computed for an output. + + :param output_name: Name of the output to check. + :param requested_outputs: Dictionary of requested outputs. + :return: True if last layer features should be computed, False otherwise. 
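+
+    For example, ``should_compute_last_layer_features("energy", outputs)`` is
+    True if ``outputs`` contains either ``"energy"`` itself or the auxiliary
+    key ``"mtt::aux::energy_last_layer_features"``.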
+ """ + if output_name in requested_outputs: + return True + ll_features_name = get_last_layer_features_name( + output_name.replace("mtt::aux::", "") + ) + return ll_features_name in requested_outputs diff --git a/src/metatrain/pet_14/modules/adaptive_cutoff.py b/src/metatrain/pet_14/modules/adaptive_cutoff.py new file mode 100644 index 0000000000..5f6206bdb1 --- /dev/null +++ b/src/metatrain/pet_14/modules/adaptive_cutoff.py @@ -0,0 +1,166 @@ +from typing import Optional + +import torch + +from .utilities import cutoff_func_bump as cutoff_func + + +# minimum value for the probe cutoff. this avoids getting too close +# to the central atom. in practice it could be also set to a larger value +DEFAULT_MIN_PROBE_CUTOFF = 0.5 +# recommended smooth cutoff width for effective neighbor number calculation +# smaller values lead to a more "step-like" behavior, but can be +# numerically unstable. in practice this will be called with the +# same cutoff as the main cutoff function +DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH = 1.0 + + +def get_adaptive_cutoffs( + centers: torch.Tensor, + edge_distances: torch.Tensor, + num_neighbors_adaptive: float, + num_nodes: int, + max_cutoff: float, + min_cutoff: float = DEFAULT_MIN_PROBE_CUTOFF, + cutoff_width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, + probe_spacing: Optional[float] = None, + weight_width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the adaptive cutoff values for each center atom. + + :param centers: Indices of the center atoms. + :param edge_distances: Distances between centers and their neighbors. + :param num_neighbors_adaptive: Target number of neighbors per atom. + :param num_nodes: Total number of center atoms. + :param max_cutoff: Maximum cutoff distance to consider. + :param min_cutoff: Minimum cutoff distance to consider. + :param cutoff_width: Width of the smooth cutoff taper region. + :param probe_spacing: Spacing between probe cutoffs. If None, it will be + automatically determined from the cutoff width. + :param weight_width: Width of the cutoff selection weight function. If None, it + will be automatically determined from the empirical neighbor counts. + :return: Adapted cutoff distances for each center atom. + """ + + # heuristic for the grid spacing of probe cutoffs, based on a + # the smoothness of the cutoff function + if probe_spacing is None: + probe_spacing = cutoff_width / 4.0 + probe_cutoffs = torch.arange( + min_cutoff, + max_cutoff, + probe_spacing, + device=edge_distances.device, + dtype=edge_distances.dtype, + ) + with torch.profiler.record_function("PET::get_effective_num_neighbors"): + effective_num_neighbors = get_effective_num_neighbors( + edge_distances, + probe_cutoffs, + centers, + num_nodes, + width=cutoff_width, + ) + + with torch.profiler.record_function("PET::get_cutoff_weights"): + cutoffs_weights = get_gaussian_cutoff_weights( + effective_num_neighbors, + num_neighbors_adaptive, + width=weight_width, + ) + with torch.profiler.record_function("PET::calculate_adapted_cutoffs"): + adapted_atomic_cutoffs = probe_cutoffs @ cutoffs_weights.T + return adapted_atomic_cutoffs + + +def get_effective_num_neighbors( + edge_distances: torch.Tensor, + probe_cutoffs: torch.Tensor, + centers: torch.Tensor, + num_nodes: int, + width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, +) -> torch.Tensor: + """ + Computes the effective number of neighbors for each probe cutoff. + + :param edge_distances: Distances between centers and their neighbors. + :param probe_cutoffs: Probe cutoff distances. 
+ :param centers: Indices of the center atoms. + :param num_nodes: Total number of center atoms. + :param width: Width of the cutoff function. + :return: Effective number of neighbors for each center atom and probe cutoff. + """ + + weights = cutoff_func( + edge_distances.unsqueeze(0), probe_cutoffs.unsqueeze(1), width + ) + + probe_num_neighbors = torch.zeros( + (len(probe_cutoffs), num_nodes), + dtype=edge_distances.dtype, + device=edge_distances.device, + ) + # accumulate the weights for all probe cutoffs and center atoms at once + probe_num_neighbors.index_add_(1, centers, weights) + probe_num_neighbors = probe_num_neighbors.T + + return probe_num_neighbors + + +def get_gaussian_cutoff_weights( + effective_num_neighbors: torch.Tensor, + num_neighbors_adaptive: float, + width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the weights for each probe cutoff based on + the effective number of neighbors using Gaussian weights + centered at the expected number of neighbors. + + :param effective_num_neighbors: Effective number of neighbors for each center atom + and probe cutoff. + :param num_neighbors_adaptive: Target maximum number of neighbors per atom. + :param width: Width of the Gaussian cutoff selection function. + :return: Weights for each probe cutoff. + """ + diff = effective_num_neighbors - num_neighbors_adaptive + + # adds a "baseline" corresponding to uniformly-distributed atoms + # this has multiple "good" effects: it pushes the cutoff "out" when + # there are few neighbors, and "in" when there are many, and it + # stabilizes the weights with respect to variations in the neighbor + # distribution when there are empty ranges leading to "flat" + # neighbor count distribution + x = torch.linspace( + 0, + 1, + effective_num_neighbors.shape[1], + device=effective_num_neighbors.device, + dtype=effective_num_neighbors.dtype, + ) + baseline = num_neighbors_adaptive * x**3 + + diff = diff + baseline.unsqueeze(0) + if width is None: + # adaptive width from neighbor-count slope along probe axis (last dim) + eps = 1e-12 + if diff.shape[-1] == 1: + # Can't compute gradient from single point; use scaled diff as proxy + width_t = diff.abs() * 0.5 + eps + else: + # Compute numerical gradient: centered differences for interior, + # one-sided differences at boundaries + (width_t,) = torch.gradient(diff, dim=-1) + width_t = width_t.abs().clamp_min(eps) + else: + width_t = torch.ones_like(diff) * width + + logw = -0.5 * (diff / width_t) ** 2 + weights = torch.exp(logw - logw.max()) + + # row-wise normalization of the weights + weights_sum = weights.sum(dim=1, keepdim=True) + weights = weights / weights_sum + + return weights diff --git a/src/metatrain/pet_14/modules/finetuning.py b/src/metatrain/pet_14/modules/finetuning.py new file mode 100644 index 0000000000..27f2086f19 --- /dev/null +++ b/src/metatrain/pet_14/modules/finetuning.py @@ -0,0 +1,264 @@ +# mypy: disable-error-code=misc +# We ignore misc errors in this file because TypedDict +# with default values is not allowed by mypy. 
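+#
+# Illustrative only: a strategy dictionary of the shape consumed by
+# ``apply_finetuning_strategy`` below. ``read_from`` is a placeholder path
+# (not consumed by that function), and rank=4 / alpha=8 match the defaults
+# used in this file.
+#
+#   strategy = {
+#       "method": "lora",
+#       "read_from": "path/to/pretrained.ckpt",
+#       "config": {"rank": 4, "alpha": 8},
+#       "inherit_heads": {},
+#   }
+#   model = apply_finetuning_strategy(model, strategy)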
+from typing import Any, Optional, Tuple + +import torch +import torch.nn as nn +from typing_extensions import Literal, NotRequired, TypedDict + + +class LoRaFinetuneConfig(TypedDict): + """Configuration for LoRA finetuning strategy.""" + + rank: int + """Rank of the LoRA matrices.""" + alpha: float + """Scaling factor for the LoRA matrices.""" + target_modules: NotRequired[list[str]] + + +class HeadsFinetuneConfig(TypedDict): + """Configuration for heads finetuning strategy.""" + + head_modules: list[str] + """List of module name prefixes for the prediction heads to finetune.""" + last_layer_modules: list[str] + """List of module name prefixes for the last layers to finetune.""" + + +class NoFinetuneHypers(TypedDict): + """Hypers that indicate that no finetuning is to be applied.""" + + read_from: None = None + """No finetuning is indicated by setting this argument to None. + + The rest of finetuning hyperparameters are then ignored. + """ + method: NotRequired[Any] + config: NotRequired[Any] + inherit_heads: NotRequired[Any] + + +class FullFinetuneHypers(TypedDict): + """Hyperparameters to use full finetuning of PET models. + + This means all model parameters are trainable. + """ + + method: Literal["full"] = "full" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: NotRequired[Any] + """No configuration needed for full finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class LoRaFinetuneHypers(TypedDict): + """Hyperparameters for LoRA finetuning of PET models. + + Injects LoRA layers and finetunes only them. + """ + + method: Literal["lora"] = "lora" + """Finetuning method to use""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: LoRaFinetuneConfig + """Configuration for LoRA finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class HeadsFinetuneHypers(TypedDict): + """Hyperparameters for heads finetuning of PET models. + + Freezes all model parameters except for the prediction heads + and last layers. + """ + + method: Literal["heads"] = "heads" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: HeadsFinetuneConfig + """Configuration for heads finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +FinetuneHypers = FullFinetuneHypers | LoRaFinetuneHypers | HeadsFinetuneHypers + + +def apply_finetuning_strategy(model: nn.Module, strategy: FinetuneHypers) -> nn.Module: + """ + Apply the specified finetuning strategy to the model. + This function modifies the model in place based on the provided strategy. + + :param model: The model to be finetuned. + :param strategy: A dictionary specifying the finetuning strategy. + The strategy method can be one of the following: + - lora: Inject LoRA layers into the model, or reapply training if already + present. 
+ - heads: Freeze all parameters except for the heads and last layers. + - full: All parameters are trainable. + Additionally, the strategy can include an "inherit_heads" key, + which is a dictionary mapping the new trainable targets to the existing + targets in the model. This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization. + :return: The modified model with the finetuning strategy applied. + """ + + for param in model.parameters(): + param.requires_grad = True + + if strategy["method"] == "full": + # Full finetuning, all parameters are trainable + pass + + elif strategy["method"] == "lora": + lora_config = strategy["config"] + lora_already_applied = any(isinstance(m, LoRALinear) for m in model.modules()) + if not lora_already_applied: + model_device = next(model.parameters()).device + model_dtype = next(model.parameters()).dtype + model = inject_lora_layers( + model, + target_modules=tuple( + lora_config.get("target_modules", ["input_linear", "output_linear"]) + ), + rank=lora_config.get("rank", 4), + alpha=lora_config.get("alpha", 8), + device=model_device, + dtype=model_dtype, + ) + + # Freeze all except LoRA + for name, param in model.named_parameters(): + if "lora_" not in name: + param.requires_grad = False + + elif strategy["method"] == "heads": + heads_config = strategy.get( + "config", + { + "head_modules": ["node_heads", "edge_heads"], + "last_layer_modules": ["node_last_layers", "edge_last_layers"], + }, + ) + + head_keywords = heads_config.get("head_modules", []) + last_layer_keywords = heads_config.get("last_layer_modules", []) + + for name, param in model.named_parameters(): + if any(name.startswith(kw) for kw in head_keywords + last_layer_keywords): + param.requires_grad = True + else: + param.requires_grad = False + + else: + raise ValueError( + f"Unknown finetuning strategy: {strategy['method']}. Available methods " + "are: 'full', 'lora', 'heads'." + ) + + model.finetune_config = strategy + + inherit_heads_config = strategy["inherit_heads"] + if inherit_heads_config: + for dest_target_name, source_target_name in inherit_heads_config.items(): + model_parameters = dict(model.named_parameters()) + if not any(f".{source_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the source target name '{source_target_name}' was not found in " + "the model. Please specify the correct source target name." + ) + if not any(f".{dest_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the destination target name '{dest_target_name}' was not found " + "in the model. Please specify the correct destination target name." + ) + for name, param in model_parameters.items(): + if f".{source_target_name}." in name: + corresponding_dest_name = name.replace( + source_target_name, dest_target_name + ) + if corresponding_dest_name in model_parameters: + model_parameters[corresponding_dest_name].data.copy_(param.data) + else: + raise ValueError( + f"Destination head '{dest_target_name}' not found in model." + ) + return model + + +def inject_lora_layers( + model: nn.Module, + target_modules: Tuple[str, ...] = ("input_linear", "output_linear"), + rank: int = 4, + alpha: float = 1.0, + device: Optional[str] = None, + dtype: Optional[torch.dtype] = None, +) -> nn.Module: + """ + Inject LoRA layers into the model. 
+ This function replaces the specified linear layers in the model with + LoRALinear layers. + + :param model: The model to modify. + :param target_modules: A tuple of strings specifying the names of the attributes of + the modules to be replaced with LoRA layers. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + :param device: The device to which the LoRA layers should be moved. If None, the + LoRA layers will be on the same device as the original model. + :param dtype: The data type to which the LoRA layers should be converted. If + None, the LoRA layers will have the same dtype as the original model. + :return: The modified model with LoRA layers injected. + """ + for _, module in model.named_modules(): + for attr in target_modules: + if hasattr(module, attr): + linear = getattr(module, attr) + if isinstance(linear, nn.Linear): + lora_linear = LoRALinear(linear, rank=rank, alpha=alpha) + lora_linear = lora_linear.to(dtype=dtype, device=device) + setattr(module, attr, lora_linear) + return model + + +class LoRALinear(nn.Module): + """ + LoRA Linear layer. + This is a wrapper around nn.Linear that adds LoRA functionality. + LoRA is a technique for low-rank adaptation of large language models. + It allows for efficient fine-tuning of large models by injecting low-rank + matrices into the model's weights. + + :param linear_layer: The original linear layer to be wrapped. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + """ + + def __init__(self, linear_layer: nn.Module, rank: int = 4, alpha: float = 1.0): + super().__init__() + self.linear = linear_layer + self.lora_A = nn.Linear(linear_layer.in_features, rank, bias=False) + self.lora_B = nn.Linear(rank, linear_layer.out_features, bias=False) + self.scaling = alpha / rank + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + self.scaling * self.lora_B(self.lora_A(x)) diff --git a/src/metatrain/pet_14/modules/nef.py b/src/metatrain/pet_14/modules/nef.py new file mode 100644 index 0000000000..20ed6dc8fc --- /dev/null +++ b/src/metatrain/pet_14/modules/nef.py @@ -0,0 +1,248 @@ +""" +Module with functions to manipulate NEF (Node Edge Feature) arrays. + +The NEF representation is what the internals of PET use. +In the NEF representation, the first dimension is the center node +(i.e. the "i" node in an "i -> j" edge), and the second dimension +is the edges for that node. Not all center nodes have the same number +of edges, so padding is used to ensure that all nodes have the same +number of edges. + +Most of the functions have the purpose of converting between +edge arrays with shape (n_edges, ...) and NEF arrays with shape +(n_nodes, n_edges_per_node, ...). +""" + +from typing import List, Optional, Tuple + +import torch + + +def get_nef_indices( + centers: torch.Tensor, n_nodes: int, n_edges_per_node: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes tensors of indices useful to convert between edge + and NEF layouts; the usage and function of `nef_indices` and + `nef_to_edges_neighbor` is clear in the ``edge_array_to_nef`` + and ``nef_array_to_edges`` functions below. + + :param centers: A 1D tensor of shape (n_edges,) containing the + indices of the center nodes for each edge, with the center nodes + being the "i" node in an "i -> j" edge. + :param n_nodes: The number of nodes in the graph. + :param n_edges_per_node: The maximum number of edges per node. 
+ + :return: A tuple with three tensors (nef_indices, nef_to_edges_neighbor, nef_mask). + In particular: + nef_array = edge_array[nef_indices] + edge_array = nef_array[centers, nef_to_edges_neighbor] + The third output, nef_mask, is a mask that can be used to + filter out the padding values in the NEF array, as different + nodes will have, in general, different number of edges. + """ + + bincount = torch.bincount(centers, minlength=n_nodes) + + arange = torch.arange(n_edges_per_node, device=centers.device) + arange_expanded = arange.view(1, -1).expand(n_nodes, -1) + nef_mask = arange_expanded < bincount.view(-1, 1) + + argsort = torch.argsort(centers, stable=True) + + nef_indices = torch.zeros( + (n_nodes, n_edges_per_node), dtype=torch.long, device=centers.device + ) + nef_indices[nef_mask] = argsort + + nef_to_edges_neighbor = torch.empty_like(centers, dtype=torch.long) + nef_to_edges_neighbor[argsort] = arange_expanded[nef_mask] + + return nef_indices, nef_to_edges_neighbor, nef_mask + + +def get_corresponding_edges(array: torch.Tensor) -> torch.Tensor: + """ + Computes the corresponding edge (i.e., the edge that goes in the + opposite direction) for each edge in the array; this is useful + in the message-passing operation. + + :param array: A 2D tensor of shape (n_edges, 5). For each i -> j + edge, the first column contains the index of the center node i, + the second column contains the index of the neighbor node j, + and the last three columns contain the cell shifts along x, y, and z + directions, respectively. + + :return: A 1D tensor of shape (n_edges,) containing, for each edge, + the index of the corresponding edge (i.e., the edge that goes + in the opposite direction). If the input array is empty, an + empty tensor is returned. + """ + + if array.numel() == 0: + return torch.empty((0,), dtype=array.dtype, device=array.device) + + array = array.to(torch.int64) # avoid overflow + + centers = array[:, 0] + neighbors = array[:, 1] + cell_shifts_x = array[:, 2] + cell_shifts_y = array[:, 3] + cell_shifts_z = array[:, 4] + + # will be useful later + negative_cell_shifts_x = -cell_shifts_x + negative_cell_shifts_y = -cell_shifts_y + negative_cell_shifts_z = -cell_shifts_z + + # create a unique identifier for each edge + # first, we shift the cell_shifts so that the minimum value is 0 + min_cell_shift_x = cell_shifts_x.min() + cell_shifts_x = cell_shifts_x - min_cell_shift_x + negative_cell_shifts_x = negative_cell_shifts_x - min_cell_shift_x + + min_cell_shift_y = cell_shifts_y.min() + cell_shifts_y = cell_shifts_y - min_cell_shift_y + negative_cell_shifts_y = negative_cell_shifts_y - min_cell_shift_y + + min_cell_shift_z = cell_shifts_z.min() + cell_shifts_z = cell_shifts_z - min_cell_shift_z + negative_cell_shifts_z = negative_cell_shifts_z - min_cell_shift_z + + max_centers_neigbors = centers.max() + 1 # same as neighbors.max() + 1 + max_shift_x = cell_shifts_x.max() + 1 + max_shift_y = cell_shifts_y.max() + 1 + max_shift_z = cell_shifts_z.max() + 1 + + size_1 = max_shift_z + size_2 = max_shift_y * size_1 + size_3 = max_shift_x * size_2 + size_4 = max_centers_neigbors * size_3 + + unique_id = ( + centers * size_4 + + neighbors * size_3 + + cell_shifts_x * size_2 + + cell_shifts_y * size_1 + + cell_shifts_z + ) + + # the inverse is the same, but centers and neighbors are swapped + # and we use the negative values of the cell_shifts + unique_id_inverse = ( + neighbors * size_4 + + centers * size_3 + + negative_cell_shifts_x * size_2 + + negative_cell_shifts_y * size_1 + + 
negative_cell_shifts_z + ) + + unique_id_argsort = unique_id.argsort() + unique_id_inverse_argsort = unique_id_inverse.argsort() + + corresponding_edges = torch.empty_like(centers) + corresponding_edges[unique_id_argsort] = unique_id_inverse_argsort + + return corresponding_edges.to(array.dtype) + + +def edge_array_to_nef( + edge_array: torch.Tensor, + nef_indices: torch.Tensor, + mask: Optional[torch.Tensor] = None, + fill_value: float = 0.0, +) -> torch.Tensor: + """Converts an edge array to a NEF array. + + :param edge_array: A tensor where the first dimension is the index of + the edge, i.e. with shape (n_edges, ...). + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param mask: An optional boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. If provided, + the output NEF array will have the values in the positions + where the mask is False set to ``fill_value``. + :param fill_value: The value to use to fill the positions in the + NEF array where the mask is False. Only used if ``mask`` is + provided. + + :return: A tensor with the same information as ``edge_array``, + but in NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + If ``mask`` is provided, the values in the positions where + the mask is False are set to ``fill_value``. + """ + if mask is None: + return edge_array[nef_indices] + else: + return torch.where( + mask.reshape(mask.shape + (1,) * (len(edge_array.shape) - 1)), + edge_array[nef_indices], + fill_value, + ) + + +def nef_array_to_edges( + nef_array: torch.Tensor, centers: torch.Tensor, nef_to_edges_neighbor: torch.Tensor +) -> torch.Tensor: + """Converts a NEF array to an edge array. + + :param nef_array: A tensor where the first two dimensions are the + indices of the NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + :param centers: The indices of the center nodes for each edge. + :param nef_to_edges_neighbor: The indices of the edges for each + neighbor in the NEF layout, as returned by the ``get_nef_indices`` function. + + :return: A tensor with the same information as ``nef_array``, + but in edge layout, i.e. with shape (n_edges, ...). + """ + return nef_array[centers, nef_to_edges_neighbor] + + +def compute_reversed_neighbor_list( + nef_indices: torch.Tensor, + corresponding_edges: torch.Tensor, + nef_mask: torch.Tensor, +) -> torch.Tensor: + """ + Creates a reversed neighborlist, where for each + center atom `i` and its neighbor `j` in the original + neighborlist, the position of atom `i` in the list + of neighbors of atom `j` is returned. + + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param corresponding_edges: The indices of the corresponding edges, + as returned by the ``get_corresponding_edges`` function. + :param nef_mask: A boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. + :return: A tensor of the same shape as ``nef_indices``, + where each entry contains the position of the center + atom in the neighborlist of the corresponding neighbor atom. 
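# --- Editor's sketch (not part of the diff): toy round trip between the edge
# layout and the NEF layout using the helpers above. Assumes the module path
# introduced in this diff; the values are illustrative.
import torch
from metatrain.pet_14.modules.nef import (
    edge_array_to_nef,
    get_nef_indices,
    nef_array_to_edges,
)

centers = torch.tensor([0, 0, 1])  # node 0 has two edges, node 1 has one, node 2 none
edge_feats = torch.tensor([[1.0], [2.0], [3.0]])

nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices(
    centers, n_nodes=3, n_edges_per_node=2
)
nef = edge_array_to_nef(edge_feats, nef_indices, nef_mask, fill_value=0.0)
# nef has shape (3, 2, 1); the padded slots of nodes 1 and 2 hold the fill value

back = nef_array_to_edges(nef, centers, nef_to_edges_neighbor)
assert torch.equal(back, edge_feats)  # the round trip recovers the edge layout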
+ """ + num_atoms, max_num_neighbors = nef_indices.shape + + flat_edge_indices = nef_indices.reshape(-1) + flat_positions = torch.arange(max_num_neighbors, device=nef_indices.device).repeat( + num_atoms + ) + flat_mask = nef_mask.reshape(-1) + + if flat_edge_indices.numel() == 0: + max_edge_index = 0 + else: + max_edge_index = int(flat_edge_indices.max().item()) + 1 + size: List[int] = [max_edge_index] + + edge_index_to_position = torch.full( + size, + 0, + dtype=torch.long, + device=nef_indices.device, + ) + edge_index_to_position[flat_edge_indices[flat_mask]] = flat_positions[flat_mask] + + reverse_edge_idx = corresponding_edges[nef_indices] + reversed_neighbor_list = edge_index_to_position[reverse_edge_idx] + reversed_neighbor_list = reversed_neighbor_list.masked_fill(~nef_mask, 0) + + return reversed_neighbor_list diff --git a/src/metatrain/pet_14/modules/structures.py b/src/metatrain/pet_14/modules/structures.py new file mode 100644 index 0000000000..f7e67f9eaf --- /dev/null +++ b/src/metatrain/pet_14/modules/structures.py @@ -0,0 +1,377 @@ +from typing import List, Optional, Tuple + +import torch +from metatensor.torch import Labels +from metatomic.torch import NeighborListOptions, System + +from .adaptive_cutoff import get_adaptive_cutoffs +from .nef import ( + compute_reversed_neighbor_list, + edge_array_to_nef, + get_corresponding_edges, + get_nef_indices, +) +from .utilities import cutoff_func_bump, cutoff_func_cosine + + +def concatenate_structures( + systems: List[System], + neighbor_list_options: NeighborListOptions, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, +]: + """ + Concatenate a list of systems into a single batch. + + :param systems: List of systems to concatenate. + :param neighbor_list_options: Options for the neighbor list. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the concatenated positions, centers, neighbors, + species, cells, cell shifts, system indices, and sample labels. 
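# --- Editor's sketch (not part of the diff): the index bookkeeping performed by
# the loop below. Each system's neighbor-list indices are shifted by the number
# of atoms already stacked (node_counter), so the concatenated centers/neighbors
# index directly into the stacked positions tensor. Plain tensors only, no
# metatomic objects.
import torch

systems = [
    {"n_atoms": 2, "centers": torch.tensor([0, 1])},
    {"n_atoms": 3, "centers": torch.tensor([0, 2])},
]
node_counter = 0
all_centers = []
for s in systems:
    all_centers.append(s["centers"] + node_counter)
    node_counter += s["n_atoms"]
print(torch.cat(all_centers))  # tensor([0, 1, 2, 4]): the second system starts at index 2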
+ """ + + positions: List[torch.Tensor] = [] + centers: List[torch.Tensor] = [] + neighbors: List[torch.Tensor] = [] + species: List[torch.Tensor] = [] + cell_shifts: List[torch.Tensor] = [] + cells: List[torch.Tensor] = [] + system_indices: List[torch.Tensor] = [] + atom_indices: List[torch.Tensor] = [] + node_counter = 0 + + for i, system in enumerate(systems): + assert len(system.known_neighbor_lists()) >= 1, "no neighbor list found" + neighbor_list = system.get_neighbor_list(neighbor_list_options) + nl_values = neighbor_list.samples.values + + centers_values = nl_values[:, 0] + neighbors_values = nl_values[:, 1] + cell_shifts_values = nl_values[:, 2:] + + if selected_atoms is not None: + system_selected_atoms = selected_atoms.values[:, 1][ + selected_atoms.values[:, 0] == i + ] + unique_centers = torch.unique(centers_values) + system_selected_atoms = torch.unique( + torch.cat([system_selected_atoms, unique_centers]) + ) + # calculate the mapping from the ghost atoms to the real atoms + if torch.numel(unique_centers) == 0: + max_center_index = -1 + else: + max_center_index = int(unique_centers.max()) + ghost_to_real_index = torch.full( + [ + max_center_index + 1, + ], + -1, + device=centers_values.device, + dtype=centers_values.dtype, + ) + for j, unique_center_index in enumerate(unique_centers): + ghost_to_real_index[unique_center_index] = j + + centers_values = ghost_to_real_index[centers_values] + neighbors_values = ghost_to_real_index[neighbors_values] + else: + system_selected_atoms = torch.arange( + len(system), device=system.positions.device + ) + + positions.append(system.positions[system_selected_atoms]) + species.append(system.types[system_selected_atoms]) + + centers.append(centers_values + node_counter) + neighbors.append(neighbors_values + node_counter) + cell_shifts.append(cell_shifts_values) + + cells.append(system.cell) + + node_counter += len(system_selected_atoms) + system_indices.append( + torch.full((len(system_selected_atoms),), i, device=system.positions.device) + ) + atom_indices.append( + torch.arange(len(system_selected_atoms), device=system.positions.device) + ) + + positions = torch.cat(positions) + centers = torch.cat(centers) + neighbors = torch.cat(neighbors) + species = torch.cat(species) + cells = torch.stack(cells) + cell_shifts = torch.cat(cell_shifts) + system_indices = torch.cat(system_indices) + atom_indices = torch.cat(atom_indices) + + sample_values = torch.stack( + [system_indices, atom_indices], + dim=1, + ) + sample_labels = Labels( + names=["system", "atom"], + values=sample_values, + assume_unique=True, + ) + + return ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) + + +def systems_to_batch( + systems: List[System], + options: NeighborListOptions, + all_species_list: List[int], + species_to_species_index: torch.Tensor, + cutoff_function: str, + cutoff_width: float, + num_neighbors_adaptive: Optional[float] = None, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, + torch.Tensor, + torch.Tensor, +]: + """ + Converts a list of systems to a batch required for the PET model. + + :param systems: List of systems to convert to a batch. + :param options: Options for the neighbor list. + :param all_species_list: List of all atomic species in the dataset. + :param species_to_species_index: Mapping from atomic species to species indices. 
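# --- Editor's sketch (not part of the diff): the ghost-to-real remapping used
# above when `selected_atoms` is given. Sparse atom indices appearing in the
# neighbor list are mapped onto a contiguous 0..N-1 range; the values are
# illustrative and the real code also unions in the explicitly selected atoms.
import torch

centers = torch.tensor([2, 2, 5, 7])  # sparse "ghost" indices
unique_centers = torch.unique(centers)  # tensor([2, 5, 7])
ghost_to_real = torch.full((int(unique_centers.max()) + 1,), -1)
ghost_to_real[unique_centers] = torch.arange(len(unique_centers))
print(ghost_to_real[centers])  # tensor([0, 0, 1, 2])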
+ :param cutoff_function: Type of the smoothing function at the cutoff. + :param cutoff_width: Width of the cutoff function for a cutoff mask. + :param num_neighbors_adaptive: Optional maximum number of neighbors per atom. + If provided, the adaptive cutoff scheme will be used for each atom to + approximately select this number of neighbors. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the batch tensors. + The batch consists of the following tensors: + - `element_indices_nodes`: The atomic species of the central atoms + - `element_indices_neighbors`: The atomic species of the neighboring atoms + - `edge_vectors`: The cartesian edge vectors between the central atoms and their + neighbors + - `edge_distances`: The distances between the central atoms and their neighbors + - `padding_mask`: A padding mask indicating which neighbors are real, and which + are padded + - `reverse_neighbor_index`: The reversed neighbor list for each central atom + - `cutoff_factors`: The cutoff function values for each edge + - `system_indices`: The system index for each atom in the batch + - `sample_labels`: Labels indicating the system and atom indices for each atom + + """ + ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) = concatenate_structures(systems, options, selected_atoms) + + # somehow the backward of this operation is very slow at evaluation, + # where there is only one cell, therefore we simplify the calculation + # for that case + if len(cells) == 1: + cell_contributions = cell_shifts.to(cells.dtype) @ cells[0] + else: + cell_contributions = torch.einsum( + "ab, abc -> ac", + cell_shifts.to(cells.dtype), + cells[system_indices[centers]], + ) + edge_vectors = positions[neighbors] - positions[centers] + cell_contributions + edge_distances = torch.norm(edge_vectors, dim=-1) + 1e-15 + + if selected_atoms is not None: + if torch.numel(centers) == 0: + num_nodes = 0 + else: + num_nodes = int(centers.max()) + 1 + else: + num_nodes = len(positions) + + if num_neighbors_adaptive is not None: + with torch.profiler.record_function("PET::get_adaptive_cutoffs"): + # Adaptive cutoff scheme to approximately select `num_neighbors_adaptive` + # neighbors for each atom + atomic_cutoffs = get_adaptive_cutoffs( + centers, + edge_distances, + num_neighbors_adaptive, + num_nodes, + options.cutoff, + cutoff_width=cutoff_width, + ) + # Symmetrize the cutoffs between pairs of atoms (PET needs this symmetry + # due to its corresponding edge indexing ij -> ji) + pair_cutoffs = (atomic_cutoffs[centers] + atomic_cutoffs[neighbors]) / 2.0 + with torch.profiler.record_function("PET::adaptive_cutoff_masking"): + # Apply cutoff mask + cutoff_mask = edge_distances <= pair_cutoffs + + pair_cutoffs = pair_cutoffs[cutoff_mask] + centers = centers[cutoff_mask] + neighbors = neighbors[cutoff_mask] + edge_vectors = edge_vectors[cutoff_mask] + cell_shifts = cell_shifts[cutoff_mask] + edge_distances = edge_distances[cutoff_mask] + else: + pair_cutoffs = options.cutoff * torch.ones( + len(centers), device=positions.device, dtype=positions.dtype + ) + + num_neighbors = torch.bincount(centers, minlength=num_nodes) + max_edges_per_node = int(torch.max(num_neighbors)) + + # uncomment these to print out stats on the adaptive cutoff behavior + # print("adaptive_cutoffs", *pair_cutoffs.tolist()) + # print("num_neighbors", *num_neighbors.tolist()) + + if cutoff_function.lower() == "bump": + # use bump switching 
function for adaptive cutoff
+        cutoff_factors = cutoff_func_bump(edge_distances, pair_cutoffs, cutoff_width)
+    elif cutoff_function.lower() == "cosine":
+        # backward-compatible cosine switching for fixed cutoff
+        cutoff_factors = cutoff_func_cosine(edge_distances, pair_cutoffs, cutoff_width)
+    else:
+        raise ValueError(
+            f"Unknown cutoff function type: {cutoff_function}. "
+            f"Supported types are 'Cosine' and 'Bump'."
+        )
+
+    # Convert to NEF (Node-Edge-Feature) format:
+    nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices(
+        centers, num_nodes, max_edges_per_node
+    )
+
+    # Element indices
+    element_indices_nodes = species_to_species_index[species]
+    element_indices_neighbors = element_indices_nodes[neighbors]
+
+    # Send everything to NEF:
+    edge_vectors = edge_array_to_nef(edge_vectors, nef_indices)
+    edge_distances = torch.sqrt(torch.sum(edge_vectors**2, dim=2) + 1e-15)
+    element_indices_neighbors = edge_array_to_nef(
+        element_indices_neighbors, nef_indices
+    )
+    cutoff_factors = edge_array_to_nef(cutoff_factors, nef_indices, nef_mask, 0.0)
+
+    corresponding_edges = get_corresponding_edges(
+        torch.concatenate(
+            [centers.unsqueeze(-1), neighbors.unsqueeze(-1), cell_shifts],
+            dim=-1,
+        )
+    )
+
+    # These are the two arrays we need for message passing with edge reversals,
+    # if indexing happens in a two-dimensional way:
+    # edges_ji = edges_ij[reversed_neighbor_list, neighbors_index]
+    reversed_neighbor_list = compute_reversed_neighbor_list(
+        nef_indices, corresponding_edges, nef_mask
+    )
+    neighbors_index = edge_array_to_nef(neighbors, nef_indices).to(torch.int64)
+
+    # Here, we compute the array that allows indexing into a flattened
+    # version of the edge array (where the first two dimensions are merged):
+    reverse_neighbor_index = (
+        neighbors_index * neighbors_index.shape[1] + reversed_neighbor_list
+    )
+    # At this point, we have `reverse_neighbor_index[~nef_mask] = 0`, which however
+    # creates too many of the same index which slows down backward enormously.
+    # (See https://github.com/pytorch/pytorch/issues/41162)
+    # We therefore replace the padded indices with a sequence of unique indices.
+    reverse_neighbor_index[~nef_mask] = torch.arange(
+        int(torch.sum(~nef_mask)), device=reverse_neighbor_index.device
+    )
+
+    return (
+        element_indices_nodes,
+        element_indices_neighbors,
+        edge_vectors,
+        edge_distances,
+        nef_mask,
+        reverse_neighbor_index,
+        cutoff_factors,
+        system_indices,
+        sample_labels,
+        centers,
+        nef_to_edges_neighbor,
+    )
+
+
+def get_pair_sample_labels(
+    systems: List[System],
+    sample_labels: Labels,
+    nl_options: NeighborListOptions,
+    device: torch.device,
+) -> Labels:
+    """
+    Builds the pair sample labels for the input ``systems``, based on the pre-computed
+    neighbor list. These are 'off-site', i.e. not including self-interactions.
+
+    :param systems: List of systems to build the pair sample labels for.
+    :param sample_labels: The sample labels for per-atom quantities.
+    :param nl_options: The neighbor list options to use for building the offsite labels.
+    :param device: The device to put the labels on.
+    :return: The ``Labels`` object containing the off-site pair samples.
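# --- Editor's sketch (not part of the diff): the flattened reverse-edge
# indexing built above. With NEF arrays of shape (n_nodes, max_edges, d),
# merging the first two axes lets a single gather fetch the j -> i message for
# every i -> j edge. Shapes and indices are illustrative.
import torch

n_nodes, max_edges, d = 4, 3, 8
edge_messages = torch.randn(n_nodes, max_edges, d)

# suppose edge (i=0, slot 0) points to node j=2, and i sits at slot 1 of node 2's list
neighbors_index = torch.tensor(2)
reversed_position = torch.tensor(1)

flat = edge_messages.reshape(n_nodes * max_edges, d)
reverse_index = neighbors_index * max_edges + reversed_position
assert torch.equal(flat[reverse_index], edge_messages[2, 1])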
+ """ + sample_names = [ + "system", + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ] + + pair_sample_values = [] + for system_idx, system in enumerate(systems): + neighbor_list = system.get_neighbor_list(nl_options) + nl_values = neighbor_list.samples.values + + pair_sample_values.append( + torch.hstack( + [ + torch.full( + (nl_values.shape[0], 1), + system_idx, + dtype=torch.int32, + device=device, + ), + nl_values, + ], + ) + ) + pair_sample_values = torch.vstack(pair_sample_values) + pair_sample_labels = Labels(sample_names, pair_sample_values).to(device=device) + + return pair_sample_labels diff --git a/src/metatrain/pet_14/modules/transformer.py b/src/metatrain/pet_14/modules/transformer.py new file mode 100644 index 0000000000..6f63402da2 --- /dev/null +++ b/src/metatrain/pet_14/modules/transformer.py @@ -0,0 +1,563 @@ +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from .utilities import DummyModule +from sphericart.torch import SolidHarmonics + + +AVAILABLE_NORMALIZATIONS = ["LayerNorm", "RMSNorm"] +AVAILABLE_TRANSFORMER_TYPES = ["PostLN", "PreLN"] +AVAILABLE_ACTIVATIONS = ["SiLU", "SwiGLU"] + + +class FeedForward(nn.Module): + def __init__(self, d_model: int, dim_feedforward: int, activation: str) -> None: + super().__init__() + + # Check if activation is "swiglu" string + if activation.lower() == "swiglu": + # SwiGLU mode: single projection produces both "value" and "gate" + self.w_in = nn.Linear(d_model, 2 * dim_feedforward, bias=False) + self.w_out = nn.Linear(dim_feedforward, d_model, bias=False) + self.activation = torch.nn.Identity() + self.is_swiglu = True + else: + # Standard mode: regular activation function + self.w_in = nn.Linear(d_model, dim_feedforward, bias=False) + self.w_out = nn.Linear(dim_feedforward, d_model, bias=False) + self.activation = getattr(F, activation.lower()) + self.is_swiglu = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_swiglu: + # SwiGLU activation: split into value and gate + v, g = self.w_in(x).chunk(2, dim=-1) + x = v * torch.sigmoid(g) + x = self.w_out(x) + else: + # Standard activation + x = self.w_in(x) + x = self.activation(x) + x = self.w_out(x) + return x + + +class AttentionBlock(nn.Module): + """ + Multi-head attention block. + + :param total_dim: The total dimension of the input and output tensors. + :param num_heads: The number of attention heads. + :param temperature: An additional scaling factor for attention scores. + This is combined with the standard scaling by the square root of + the head dimension. + :param epsilon: A small value to avoid division by zero. + """ + + def __init__( + self, + total_dim: int, + num_heads: int, + temperature: float, + epsilon: float = 1e-15, + ) -> None: + super(AttentionBlock, self).__init__() + + self.input_linear = nn.Linear(total_dim, 3 * total_dim, bias=False) + self.output_linear = nn.Linear(total_dim, total_dim, bias=False) + + self.num_heads = num_heads + self.epsilon = epsilon + self.temperature = temperature + if total_dim % num_heads != 0: + raise ValueError("total dimension is not divisible by the number of heads") + self.head_dim = total_dim // num_heads + + def forward( + self, x: torch.Tensor, cutoff_factors: torch.Tensor, use_manual_attention: bool + ) -> torch.Tensor: + """ + Forward pass for the attention block. 
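# --- Editor's sketch (not part of the diff): the gated feed-forward path of the
# FeedForward class above. A single projection produces a "value" half and a
# "gate" half, and the output is value * sigmoid(gate); note that the canonical
# SwiGLU variant gates with silu(gate) instead. Dimensions are illustrative.
import torch
from torch import nn

d_model, dim_feedforward = 16, 32
w_in = nn.Linear(d_model, 2 * dim_feedforward, bias=False)
w_out = nn.Linear(dim_feedforward, d_model, bias=False)

x = torch.randn(5, d_model)
v, g = w_in(x).chunk(2, dim=-1)  # split into value and gate halves
y = w_out(v * torch.sigmoid(g))
print(y.shape)  # torch.Size([5, 16])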
+ + :param x: The input tensor, of shape (batch_size, seq_length, total_dim) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: The output tensor, of shape (batch_size, seq_length, total_dim) + """ + initial_shape = x.shape + x = self.input_linear(x) + x = x.reshape( + initial_shape[0], initial_shape[1], 3, self.num_heads, self.head_dim + ) + x = x.permute(2, 0, 3, 1, 4) + + queries, keys, values = x[0], x[1], x[2] + attn_weights = torch.clamp(cutoff_factors[:, None, :, :], self.epsilon) + attn_weights = torch.log(attn_weights) + if use_manual_attention: + x = manual_attention(queries, keys, values, attn_weights, self.temperature) + else: + x = torch.nn.functional.scaled_dot_product_attention( + queries, + keys, + values, + attn_mask=attn_weights, + scale=1.0 / (self.head_dim**0.5 * self.temperature), + ) + x = x.transpose(1, 2).reshape(initial_shape) + x = self.output_linear(x) + return x + + +class TransformerLayer(torch.nn.Module): + """ + Single layer of a Transformer. + + :param d_model: The dimension of the model. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param temperature: An additional scaling factor for attention scores. 
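# --- Editor's sketch (not part of the diff): why the attention bias above is
# log(cutoff_factors). Adding log(c) to the logits before the softmax multiplies
# the unnormalized attention weight of each edge by c, so edges fade out
# smoothly as they approach the cutoff. Values are illustrative.
import torch

logits = torch.tensor([1.0, 2.0, 3.0])
c = torch.tensor([1.0, 0.5, 0.0])  # cutoff factors; 0 means fully outside
eps = 1e-15

biased = (logits + torch.log(c.clamp(min=eps))).softmax(dim=-1)
manual = (logits.exp() * c) / (logits.exp() * c).sum()
assert torch.allclose(biased, manual, atol=1e-6)
print(biased)  # the weight of the third edge is (numerically) zero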
+ """ + + def __init__( + self, + d_model: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + temperature: float = 1.0, + ) -> None: + super(TransformerLayer, self).__init__() + self.attention = AttentionBlock(d_model, n_heads, temperature) + self.transformer_type = transformer_type + self.d_model = d_model + norm_class = getattr(nn, norm) + if norm == "LayerNorm": + self.norm_attention = norm_class(d_model, bias=False) + self.norm_mlp = norm_class(d_model, bias=False) + else: + self.norm_attention = norm_class(d_model) + self.norm_mlp = norm_class(d_model) + self.mlp = FeedForward(d_model, dim_feedforward, activation) + self.expanded_node_features = False + if dim_node_features != d_model: + self.expanded_node_features = True + self.center_contraction = nn.Linear(dim_node_features, d_model, bias=False) + self.center_expansion = nn.Linear(d_model, dim_node_features, bias=False) + if norm == "LayerNorm": + self.norm_center_features = norm_class(dim_node_features, bias=False) + else: + self.norm_center_features = norm_class(dim_node_features) + self.center_mlp = FeedForward( + dim_node_features, 2 * dim_node_features, activation + ) + else: + self.center_contraction = torch.nn.Identity() + self.center_expansion = torch.nn.Identity() + self.norm_center_features = torch.nn.Identity() + self.center_mlp = torch.nn.Identity() + + def _forward_pre_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + new_tokens = self.attention( + self.norm_attention(tokens), cutoff_factors, use_manual_attention + ) + output_node_embeddings, output_edge_embeddings = torch.split( + new_tokens, [1, new_tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + + output_edge_embeddings = edge_embeddings + output_edge_embeddings + output_edge_embeddings = output_edge_embeddings + self.mlp( + self.norm_mlp(output_edge_embeddings) + ) + + return output_node_embeddings, output_edge_embeddings + + def _forward_post_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + tokens = self.norm_attention( + tokens + self.attention(tokens, cutoff_factors, use_manual_attention) + ) + tokens = self.norm_mlp(tokens + self.mlp(tokens)) + output_node_embeddings, output_edge_embeddings = torch.split( + tokens, [1, tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + 
) + return output_node_embeddings, output_edge_embeddings + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for a single Transformer layer. + + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + if self.transformer_type == "PostLN": + node_embeddings, edge_embeddings = self._forward_post_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + if self.transformer_type == "PreLN": + node_embeddings, edge_embeddings = self._forward_pre_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class Transformer(torch.nn.Module): + """ + Transformer implementation. + + :param d_model: The dimension of the model. + :param num_layers: The number of transformer layers. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param attention_temperature: The temperature scaling factor for attention + scores. This is combined with the standard scaling by the square root of + the head dimension. + """ + + def __init__( + self, + d_model: int, + num_layers: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + attention_temperature: float = 1.0, + ) -> None: + super(Transformer, self).__init__() + if norm not in AVAILABLE_NORMALIZATIONS: + raise ValueError( + f"Unknown normalization flag: {norm}. " + f"Please choose from: {AVAILABLE_NORMALIZATIONS}" + ) + + if transformer_type not in AVAILABLE_TRANSFORMER_TYPES: + raise ValueError( + f"Unknown transformer flag: {transformer_type}. " + f"Please choose from: {AVAILABLE_TRANSFORMER_TYPES}" + ) + self.transformer_type = transformer_type + + if activation not in AVAILABLE_ACTIVATIONS: + raise ValueError( + f"Unknown activation flag: {activation}. 
" + f"Please choose from: {AVAILABLE_ACTIVATIONS}" + ) + + self.layers = nn.ModuleList( + [ + TransformerLayer( + d_model=d_model, + n_heads=n_heads, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + temperature=attention_temperature, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the Transformer. + + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + for layer in self.layers: + node_embeddings, edge_embeddings = layer( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class CartesianTransformer(torch.nn.Module): + """ + Cartesian Transformer implementation for handling 3D coordinates. + + :param cutoff: The cutoff distance for neighbor interactions. + :param cutoff_width: The width of the cutoff function. + :param d_model: The dimension of the model. + :param n_head: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param n_layers: The number of transformer layers. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param attention_temperature: The temperature scaling factor for attention scores. + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param n_atomic_species: The number of atomic species. + :param is_first: Whether this is the first transformer in the model. 
+ """ + + def __init__( + self, + cutoff: float, + cutoff_width: float, + d_model: int, + n_head: int, + dim_node_features: int, + dim_feedforward: int, + n_layers: int, + norm: str, + activation: str, + attention_temperature: float, + transformer_type: str, + n_atomic_species: int, + is_first: bool, + ) -> None: + super(CartesianTransformer, self).__init__() + self.is_first = is_first + self.cutoff = cutoff + self.cutoff_width = cutoff_width + self.trans = Transformer( + d_model=d_model, + num_layers=n_layers, + n_heads=n_head, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + attention_temperature=attention_temperature, + ) + + self.l_max = 10 + self.spherical_harmonics = SolidHarmonics(l_max=self.l_max) + self.edge_embedder = nn.Linear((self.l_max + 1) ** 2, d_model, bias=False) + self.rmsnorm = nn.LayerNorm(d_model, bias=False) + + if not is_first: + n_merge = 3 + else: + n_merge = 2 + + self.compress = nn.Sequential( + nn.Linear(n_merge * d_model, d_model, bias=False), + torch.nn.SiLU(), + nn.Linear(d_model, d_model, bias=False), + ) + + self.neighbor_embedder = DummyModule() # for torchscript + if not is_first: + self.neighbor_embedder = nn.Embedding(n_atomic_species, d_model) + + def forward( + self, + input_node_embeddings: torch.Tensor, + input_messages: torch.Tensor, + element_indices_neighbors: torch.Tensor, + edge_vectors: torch.Tensor, + padding_mask: torch.Tensor, + edge_distances: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the CartesianTransformer. + + :param input_node_embeddings: The input node embeddings, of shape + (n_nodes, d_model) + :param input_messages: The input messages to the transformer, of shape + (n_nodes, max_num_neighbors, d_model) + :param element_indices_neighbors: The atomic species of the neighboring atoms, + of shape (n_nodes, max_num_neighbors) + :param edge_vectors: The cartesian edge vectors between the central atoms and + their neighbors, of shape (n_nodes, max_num_neighbors, 3) + :param padding_mask: A padding mask indicating which neighbors are real, and + which are padded, of shape (n_nodes, max_num_neighbors) + :param edge_distances: The distances between the central atoms and their + neighbors, of shape (n_nodes, max_num_neighbors) + :param cutoff_factors: The cutoff factors for the edges, of shape + (n_nodes, max_num_neighbors) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). 
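# --- Editor's sketch (not part of the diff): the edge featurization set up
# above. Each Cartesian edge vector is expanded into (l_max + 1)**2 solid
# harmonic components and linearly mapped to d_model; the expansion is stubbed
# out with random features here, while the real model uses sphericart as above.
import torch
from torch import nn

l_max, d_model, n_edges = 10, 128, 7
n_sh = (l_max + 1) ** 2  # 121 components per edge
edge_sh = torch.randn(n_edges, n_sh)  # stand-in for the solid harmonics expansion
edge_embedder = nn.Linear(n_sh, d_model, bias=False)
print(edge_embedder(edge_sh).shape)  # torch.Size([7, 128])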
+ :return: A tuple containing: + - The output node embeddings, of shape (n_nodes, d_model) + - The output edge embeddings, of shape (n_nodes, max_num_neighbors, d_model) + """ + node_embeddings = input_node_embeddings + # edge_embeddings = [edge_vectors, edge_distances[:, :, None]] + # edge_embeddings = torch.cat(edge_embeddings, dim=2) + edge_embeddings = self.spherical_harmonics( + edge_vectors.reshape(-1, 3) + ).reshape(edge_vectors.shape[0], edge_vectors.shape[1], -1) + edge_embeddings = self.edge_embedder(edge_embeddings) + edge_embeddings = self.rmsnorm(edge_embeddings) + + if not self.is_first: + neighbor_elements_embeddings = self.neighbor_embedder( + element_indices_neighbors + ) + edge_tokens = torch.cat( + [edge_embeddings, neighbor_elements_embeddings, input_messages], dim=2 + ) + else: + neighbor_elements_embeddings = torch.empty( + 0, device=edge_vectors.device, dtype=edge_vectors.dtype + ) # for torch script + edge_tokens = torch.cat([edge_embeddings, input_messages], dim=2) + + edge_tokens = self.compress(edge_tokens) + # tokens = torch.cat([node_elements_embedding[:, None, :], tokens], dim=1) + + padding_mask_with_central_token = torch.ones( + padding_mask.shape[0], dtype=torch.bool, device=padding_mask.device + ) + total_padding_mask = torch.cat( + [padding_mask_with_central_token[:, None], padding_mask], dim=1 + ) + + cutoff_subfactors = torch.ones( + padding_mask.shape[0], + dtype=cutoff_factors.dtype, + device=padding_mask.device, + ) + cutoff_factors = torch.cat([cutoff_subfactors[:, None], cutoff_factors], dim=1) + cutoff_factors[~total_padding_mask] = 0.0 + + cutoff_factors = cutoff_factors[:, None, :] + cutoff_factors = cutoff_factors.repeat(1, cutoff_factors.shape[2], 1) + + initial_num_tokens = edge_vectors.shape[1] + max_num_tokens = input_messages.shape[1] + + output_node_embeddings, output_edge_embeddings = self.trans( + node_embeddings[:, None, :], + edge_tokens[:, :max_num_tokens, :], + cutoff_factors=cutoff_factors[ + :, : (max_num_tokens + 1), : (max_num_tokens + 1) + ], + use_manual_attention=use_manual_attention, + ) + if max_num_tokens < initial_num_tokens: + padding = torch.zeros( + output_edge_embeddings.shape[0], + initial_num_tokens - max_num_tokens, + output_edge_embeddings.shape[2], + device=output_edge_embeddings.device, + ) + output_edge_embeddings = torch.cat([output_edge_embeddings, padding], dim=1) + output_node_embeddings = output_node_embeddings.squeeze(1) + return output_node_embeddings, output_edge_embeddings + + +def manual_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: torch.Tensor, + temperature: float, +) -> torch.Tensor: + """ + Implements the attention operation manually, using basic PyTorch operations. + We need it because the built-in PyTorch attention does not support double backward, + which is needed when training with conservative forces. + + :param q: The queries + :param k: The keys + :param v: The values + :param attn_mask: The attention mask + :param temperature: An additional scaling factor for attention scores. 
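# --- Editor's sketch (not part of the diff): a hedged numerical check that the
# manual attention formula documented above matches PyTorch's built-in
# scaled_dot_product_attention when temperature is 1 and the same additive mask
# is used (PyTorch >= 2.0 assumed). Shapes are illustrative.
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 2, 5, 8) for _ in range(3))  # (batch, heads, seq, head_dim)
mask = torch.zeros(1, 1, 5, 5)

manual = torch.matmul(
    torch.softmax(torch.matmul(q, k.transpose(-2, -1)) / 8**0.5 + mask, dim=-1), v
)
built_in = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
assert torch.allclose(manual, built_in, atol=1e-5)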
+ :return: The result of the attention operation + """ + attention_weights = ( + torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5 * temperature) + ) + attn_mask + attention_weights = attention_weights.softmax(dim=-1) + attention_output = torch.matmul(attention_weights, v) + return attention_output \ No newline at end of file diff --git a/src/metatrain/pet_14/modules/utilities.py b/src/metatrain/pet_14/modules/utilities.py new file mode 100644 index 0000000000..5795c5aed8 --- /dev/null +++ b/src/metatrain/pet_14/modules/utilities.py @@ -0,0 +1,63 @@ +import torch + + +def cutoff_func_bump( + values: torch.Tensor, cutoff: torch.Tensor, width: float, eps: float = 1e-6 +) -> torch.Tensor: + """ + Bump cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :param eps: Avoid computing at values too close to the edges. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + f[mask_active] = 0.5 * ( + 1 + torch.tanh(1 / torch.tan(torch.pi * scaled_values[mask_active])) + ) + f[mask_smaller] = 1.0 + + return f + + +def cutoff_func_cosine( + values: torch.Tensor, cutoff: torch.Tensor, width: float +) -> torch.Tensor: + """ + Cosine cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + + f[mask_active] = 0.5 + 0.5 * torch.cos(torch.pi * scaled_values[mask_active]) + f[mask_smaller] = 1.0 + return f + + +class DummyModule(torch.nn.Module): + """Dummy torch module to make torchscript happy. 
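# --- Editor's sketch (not part of the diff): quick sanity check of the two
# cutoff functions above. Both return 1 well inside the cutoff, 0 at and beyond
# it, and interpolate smoothly over the last `width`. Assumes the module path
# introduced in this diff.
import torch
from metatrain.pet_14.modules.utilities import cutoff_func_bump, cutoff_func_cosine

r = torch.tensor([1.0, 4.75, 5.0])  # distances
cutoff = torch.full_like(r, 5.0)  # per-edge cutoff radius
for func in (cutoff_func_bump, cutoff_func_cosine):
    print(func.__name__, func(r, cutoff, width=0.5).tolist())  # ~[1.0, 0.5, 0.0]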
+ This model should never be run""" + + def __init__(self) -> None: + super(DummyModule, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise RuntimeError("This model should never be run") diff --git a/src/metatrain/pet_14/trainer.py b/src/metatrain/pet_14/trainer.py new file mode 100644 index 0000000000..49e82212bc --- /dev/null +++ b/src/metatrain/pet_14/trainer.py @@ -0,0 +1,639 @@ +import copy +import logging +import math +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Union, cast + +import torch +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader, DistributedSampler + +from metatrain.utils.abc import ModelInterface, TrainerInterface +from metatrain.utils.additive import get_remove_additive_transform +from metatrain.utils.augmentation import RotationalAugmenter +from metatrain.utils.data import ( + CollateFn, + CombinedDataLoader, + Dataset, + get_num_workers, + unpack_batch, + validate_num_workers, +) +from metatrain.utils.distributed.batch_utils import should_skip_batch +from metatrain.utils.distributed.distributed_data_parallel import ( + DistributedDataParallel, +) +from metatrain.utils.distributed.slurm import DistributedEnvironment +from metatrain.utils.evaluate_model import evaluate_model +from metatrain.utils.io import check_file_extension +from metatrain.utils.logging import ROOT_LOGGER, MetricLogger +from metatrain.utils.loss import LossAggregator, LossSpecification +from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists_transform, +) +from metatrain.utils.per_atom import average_by_num_atoms +from metatrain.utils.scaler import get_remove_scale_transform +from metatrain.utils.transfer import batch_to + +from . import checkpoints +from .documentation import TrainerHypers +from .model import PET +from .modules.finetuning import apply_finetuning_strategy + + +def get_scheduler( + optimizer: torch.optim.Optimizer, + train_hypers: TrainerHypers, + steps_per_epoch: int, +) -> LambdaLR: + """ + Get a CosineAnnealing learning-rate scheduler with warmup + + :param optimizer: The optimizer for which to create the scheduler. + :param train_hypers: The training hyperparameters. + :param steps_per_epoch: The number of steps per epoch. + :return: The learning rate scheduler. 
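# --- Editor's sketch (not part of the diff): the warmup + cosine schedule
# described above, exercised on a toy optimizer. The multiplier rises linearly
# over the warmup steps and then follows a cosine from 1 down to 0. The numbers
# are illustrative, not the defaults of this trainer.
import math

import torch
from torch.optim.lr_scheduler import LambdaLR

total_steps, warmup_steps = 100, 10
opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)

def lr_lambda(step: int) -> float:
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

sched = LambdaLR(opt, lr_lambda=lr_lambda)
for _ in range(20):
    opt.step()
    sched.step()
print(opt.param_groups[0]["lr"])  # past warmup, slightly below the base 1e-3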
+ """ + total_steps = train_hypers["num_epochs"] * steps_per_epoch + warmup_steps = int(train_hypers["warmup_fraction"] * total_steps) + min_lr_ratio = 0.0 # hardcoded for now, could be made configurable in the future + + def lr_lambda(current_step: int) -> float: + if current_step < warmup_steps: + # Linear warmup + return float(current_step) / float(max(1, warmup_steps)) + else: + # Cosine decay + progress = (current_step - warmup_steps) / float( + max(1, total_steps - warmup_steps) + ) + cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) + return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay + + scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) + return scheduler + + +class Trainer(TrainerInterface[TrainerHypers]): + __checkpoint_version__ = 12 + + def __init__(self, hypers: TrainerHypers) -> None: + super().__init__(hypers) + + self.optimizer_state_dict: Optional[Dict[str, Any]] = None + self.scheduler_state_dict: Optional[Dict[str, Any]] = None + self.epoch: Optional[int] = None + self.best_epoch: Optional[int] = None + self.best_metric: Optional[float] = None + self.best_model_state_dict: Optional[Dict[str, Any]] = None + self.best_optimizer_state_dict: Optional[Dict[str, Any]] = None + + def train( + self, + model: PET, + dtype: torch.dtype, + devices: List[torch.device], + train_datasets: List[Union[Dataset, torch.utils.data.Subset]], + val_datasets: List[Union[Dataset, torch.utils.data.Subset]], + checkpoint_dir: str, + ) -> None: + assert dtype in PET.__supported_dtypes__ + + is_distributed = self.hypers["distributed"] + is_finetune = self.hypers["finetune"]["read_from"] is not None + + if is_distributed: + if len(devices) > 1: + raise ValueError( + "Requested distributed training with the `multi-gpu` device. " + " If you want to run distributed training with PET, please " + "set `device` to cuda." 
+ ) + # the calculation of the device number works both when GPUs on different + # processes are not visible to each other and when they are + distr_env = DistributedEnvironment(self.hypers["distributed_port"]) + device_number = distr_env.local_rank % torch.cuda.device_count() + device = torch.device("cuda", device_number) + torch.distributed.init_process_group(backend="nccl", device_id=device) + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + else: + rank = 0 + device = devices[0] + # only one device, as we don't support non-distributed multi-gpu for now + + if is_distributed: + logging.info(f"Training on {world_size} devices with dtype {dtype}") + else: + logging.info(f"Training on device {device} with dtype {dtype}") + + # Apply fine-tuning strategy if provided + if is_finetune: + assert self.hypers["finetune"]["read_from"] is not None # for mypy + model = apply_finetuning_strategy(model, self.hypers["finetune"]) + method = self.hypers["finetune"]["method"] + num_params = sum(p.numel() for p in model.parameters()) + num_trainable_params = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + + logging.info(f"Applied finetuning strategy: {method}") + logging.info( + f"Number of trainable parameters: {num_trainable_params} " + f"[{num_trainable_params / num_params:.2%} %]" + ) + inherit_heads = self.hypers["finetune"]["inherit_heads"] + if inherit_heads: + logging.info( + "Inheriting initial weights for heads and last layers for targets: " + f"from {list(inherit_heads.values())} to " + f"{list(inherit_heads.keys())}" + ) + + # Move the model to the device and dtype: + model.to(device=device, dtype=dtype) + # The additive models of PET are always in float64 (to avoid numerical errors in + # the composition weights, which can be very large). 
+ for additive_model in model.additive_models: + additive_model.to(dtype=torch.float64) + model.scaler.to(dtype=torch.float64) + + logging.info("Calculating composition weights") + model.additive_models[0].train_model( # this is the composition model + train_datasets, + model.additive_models[1:], + self.hypers["batch_size"], + is_distributed, + self.hypers["atomic_baseline"], + ) + + if self.hypers["scale_targets"]: + logging.info("Calculating scaling weights") + model.scaler.train_model( + train_datasets, + model.additive_models, + self.hypers["batch_size"], + is_distributed, + self.hypers["fixed_scaling_weights"], + ) + + logging.info("Setting up data loaders") + + if is_distributed: + train_samplers = [ + DistributedSampler( + train_dataset, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + ) + for train_dataset in train_datasets + ] + val_samplers = [ + DistributedSampler( + val_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + for val_dataset in val_datasets + ] + else: + train_samplers = [None] * len(train_datasets) + val_samplers = [None] * len(val_datasets) + + # Extract additive models and scaler and move them to CPU/float64 so they + # can be used in the collate function + model.additive_models[0].weights_to(device="cpu", dtype=torch.float64) + additive_models = copy.deepcopy( + model.additive_models.to(dtype=torch.float64, device="cpu") + ) + model.additive_models.to(device) + model.additive_models[0].weights_to(device=device, dtype=torch.float64) + model.scaler.scales_to(device="cpu", dtype=torch.float64) + scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu")) + model.scaler.to(device) + model.scaler.scales_to(device=device, dtype=torch.float64) + + # Create collate functions: + dataset_info = model.dataset_info + train_targets = dataset_info.targets + extra_data_info = dataset_info.extra_data + rotational_augmenter = RotationalAugmenter( + target_info_dict=train_targets, extra_data_info_dict=extra_data_info + ) + requested_neighbor_lists = get_requested_neighbor_lists(model) + collate_fn_train = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ + rotational_augmenter.apply_random_augmentations, + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + collate_fn_val = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ # no augmentation for validation + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + + # Create dataloader for the training datasets: + if self.hypers["num_workers"] is None: + num_workers = get_num_workers() + logging.info( + "Number of workers for data-loading not provided and chosen " + f"automatically. Using {num_workers} workers." + ) + else: + num_workers = self.hypers["num_workers"] + validate_num_workers(num_workers) + + train_dataloaders = [] + for train_dataset, train_sampler in zip( + train_datasets, train_samplers, strict=True + ): + if len(train_dataset) < self.hypers["batch_size"]: + raise ValueError( + f"A training dataset has fewer samples " + f"({len(train_dataset)}) than the batch size " + f"({self.hypers['batch_size']}). " + "Please reduce the batch size." 
+ ) + train_dataloaders.append( + DataLoader( + dataset=train_dataset, + batch_size=self.hypers["batch_size"], + sampler=train_sampler, + shuffle=( + # the sampler takes care of this (if present) + train_sampler is None + ), + drop_last=( + # the sampler takes care of this (if present) + train_sampler is None + ), + collate_fn=collate_fn_train, + num_workers=num_workers, + ) + ) + train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True) + + # Create dataloader for the validation datasets: + val_dataloaders = [] + for val_dataset, val_sampler in zip(val_datasets, val_samplers, strict=True): + val_dataloaders.append( + DataLoader( + dataset=val_dataset, + batch_size=self.hypers["batch_size"], + sampler=val_sampler, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_val, + num_workers=num_workers, + ) + ) + val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False) + + if is_distributed: + model = DistributedDataParallel(model, device_ids=[device]) + + outputs_list = [] + for target_name, target_info in train_targets.items(): + outputs_list.append(target_name) + for gradient_name in target_info.gradients: + outputs_list.append(f"{target_name}_{gradient_name}_gradients") + + # Create a loss function: + loss_hypers = cast(Dict[str, LossSpecification], self.hypers["loss"]) # mypy + loss_fn = LossAggregator(targets=train_targets, config=loss_hypers) + logging.info("Using the following loss functions:") + for name, info in loss_fn.metadata.items(): + logging.info(f"{name}:") + main = {k: v for k, v in info.items() if k != "gradients"} + logging.info(main) + if "gradients" not in info or len(info["gradients"]) == 0: + continue + logging.info("With gradients:") + for grad, ginfo in info["gradients"].items(): + logging.info(f"\t{name}::{grad}: {ginfo}") + + if self.hypers["weight_decay"] is not None: + optimizer = torch.optim.AdamW( + model.parameters(), + lr=self.hypers["learning_rate"], + weight_decay=self.hypers["weight_decay"], + ) + else: + optimizer = torch.optim.Adam( + model.parameters(), lr=self.hypers["learning_rate"] + ) + + if self.optimizer_state_dict is not None and not is_finetune: + # try to load the optimizer state dict, but this is only possible + # if there are no new targets in the model (new parameters) + if not (model.module if is_distributed else model).has_new_targets: + optimizer.load_state_dict(self.optimizer_state_dict) + + # Create a learning rate scheduler + lr_scheduler = get_scheduler(optimizer, self.hypers, len(train_dataloader)) + + if self.scheduler_state_dict is not None and not is_finetune: + # same as the optimizer, try to load the scheduler state dict + if not (model.module if is_distributed else model).has_new_targets: + lr_scheduler.load_state_dict(self.scheduler_state_dict) + + per_structure_targets = self.hypers["per_structure_targets"] + + # Log the initial learning rate: + logging.info(f"Base learning rate: {self.hypers['learning_rate']}") + + start_epoch = -1 if self.epoch is None else self.epoch + epoch = start_epoch + + # Save the untrained model checkpoint: + if self.best_metric is None: + self.best_metric = float("inf") + logging.info("Saving untrained model checkpoint") + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = 
lr_scheduler.state_dict() + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_untrained.ckpt", + ) + + # Train the model: + logging.info("Starting training") + start_epoch = start_epoch + 1 + + for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]): + if is_distributed: + for train_sampler in train_samplers: + train_sampler.set_epoch(epoch) + train_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + val_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + if self.hypers["log_mae"]: + train_mae_calculator = MAEAccumulator( + self.hypers["log_separate_blocks"] + ) + val_mae_calculator = MAEAccumulator(self.hypers["log_separate_blocks"]) + + train_loss = 0.0 + for batch in train_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + optimizer.zero_grad() + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=True, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + train_loss_batch = loss_fn(predictions, targets, extra_data) + + if is_distributed: + # make sure all parameters contribute to the gradient calculation + # to make torch DDP happy + for param in model.parameters(): + train_loss_batch += 0.0 * param.sum() + + train_loss_batch.backward() + torch.nn.utils.clip_grad_norm_( + model.parameters(), self.hypers["grad_clip_norm"] + ) + optimizer.step() + lr_scheduler.step() + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(train_loss_batch) + train_loss += train_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + train_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + train_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_train_info = train_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_train_info.update( + train_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + val_loss = 0.0 + for batch in val_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=False, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + val_loss_batch = loss_fn(predictions, targets, extra_data) + + if 
is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(val_loss_batch) + val_loss += val_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + val_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + val_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_val_info = val_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_val_info.update( + val_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + # Now we log the information: + finalized_train_info = { + "loss": train_loss, + **finalized_train_info, + } + finalized_val_info = { + "loss": val_loss, + **finalized_val_info, + } + + if epoch == start_epoch: + metric_logger = MetricLogger( + log_obj=ROOT_LOGGER, + dataset_info=( + model.module if is_distributed else model + ).dataset_info, + initial_metrics=[finalized_train_info, finalized_val_info], + names=["training", "validation"], + ) + if epoch % self.hypers["log_interval"] == 0: + metric_logger.log( + metrics=[finalized_train_info, finalized_val_info], + epoch=epoch, + rank=rank, + learning_rate=optimizer.param_groups[0]["lr"], + ) + + val_metric = get_selected_metric( + finalized_val_info, self.hypers["best_model_metric"] + ) + if val_metric < self.best_metric: + self.best_metric = val_metric + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + if epoch % self.hypers["checkpoint_interval"] == 0: + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + self.epoch = epoch + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_{epoch}.ckpt", + ) + + # prepare for the checkpoint that will be saved outside the function + self.epoch = epoch + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + + if is_distributed: + torch.distributed.destroy_process_group() + + def save_checkpoint(self, model: ModelInterface, path: Union[str, Path]) -> None: + checkpoint = model.get_checkpoint() + if self.best_model_state_dict is not None: + self.best_model_state_dict["finetune_config"] = model.finetune_config + checkpoint.update( + { + "trainer_ckpt_version": self.__checkpoint_version__, + "train_hypers": self.hypers, + "epoch": self.epoch, + "optimizer_state_dict": self.optimizer_state_dict, + "scheduler_state_dict": self.scheduler_state_dict, + "best_epoch": self.best_epoch, + "best_metric": self.best_metric, + "best_model_state_dict": self.best_model_state_dict, + "best_optimizer_state_dict": self.best_optimizer_state_dict, + } + ) + torch.save( + checkpoint, + check_file_extension(path, ".ckpt"), + ) + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + hypers: TrainerHypers, + context: Literal["restart", "finetune"], + ) -> "Trainer": + trainer = cls(hypers) + trainer.optimizer_state_dict = 
checkpoint["optimizer_state_dict"] + trainer.scheduler_state_dict = checkpoint["scheduler_state_dict"] + trainer.epoch = checkpoint["epoch"] + trainer.best_epoch = checkpoint["best_epoch"] + trainer.best_metric = checkpoint["best_metric"] + trainer.best_model_state_dict = checkpoint["best_model_state_dict"] + trainer.best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"] + + return trainer + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["trainer_ckpt_version"] == v: + update = getattr(checkpoints, f"trainer_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["trainer_ckpt_version"] = v + 1 + + if checkpoint["trainer_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using " + f"trainer version {checkpoint['trainer_ckpt_version']}, while the " + f"current trainer version is {cls.__checkpoint_version__}." + ) + return checkpoint diff --git a/src/metatrain/pet_2/__init__.py b/src/metatrain/pet_2/__init__.py new file mode 100644 index 0000000000..01db7d8aa9 --- /dev/null +++ b/src/metatrain/pet_2/__init__.py @@ -0,0 +1,21 @@ +from .model import PET +from .trainer import Trainer + + +__model__ = PET +__trainer__ = Trainer +__capabilities__ = { + "supported_devices": __model__.__supported_devices__, + "supported_dtypes": __model__.__supported_dtypes__, +} + +__authors__ = [ + ("Sergey Pozdnyakov ", "@spozdn"), + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] + +__maintainers__ = [ + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] diff --git a/src/metatrain/pet_2/checkpoints.py b/src/metatrain/pet_2/checkpoints.py new file mode 100644 index 0000000000..020b1e6590 --- /dev/null +++ b/src/metatrain/pet_2/checkpoints.py @@ -0,0 +1,422 @@ +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + + +########################### +# MODEL ################### +########################### + + +def model_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "additive_models.0.model.type_to_index" not in state_dict: + state_dict["additive_models.0.model.type_to_index"] = state_dict.pop( + "additive_models.0.type_to_index" + ) + + +def model_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "train_hypers" in checkpoint: + finetune_config = checkpoint["train_hypers"].get("finetune", {}) + else: + finetune_config = {} + state_dict["finetune_config"] = finetune_config + + +def model_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["epoch"] = checkpoint.get("epoch") + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + if checkpoint["best_model_state_dict"] is not None: + checkpoint["best_model_state_dict"] = checkpoint.get("best_model_state_dict") + + +def model_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 checkpoint to v5. + + :param checkpoint: The checkpoint to update. 
+ """ + checkpoint["model_data"]["dataset_info"]._atomic_types = list( + checkpoint["model_data"]["dataset_info"]._atomic_types + ) + + +def model_update_v5_v6(checkpoint: dict) -> None: + """ + Update a v5 checkpoint to v6. + + :param checkpoint: The checkpoint to update. + """ + if not checkpoint["best_model_state_dict"]: + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + +def model_update_v6_v7(checkpoint: dict) -> None: + """ + Update model checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to be updated. + """ + # this update consists in changes in the scaler + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if ( + "scaler.scales" not in state_dict + and "scaler.dummy_buffer" in state_dict + and "scaler.model.type_to_index" in state_dict + ): + continue # already updated + old_scales_tensor = state_dict.pop("scaler.scales") + old_output_name_to_output_index = {} + for target_index, target_name in enumerate( + checkpoint["model_data"]["dataset_info"].targets.keys() + ): + old_output_name_to_output_index[target_name] = target_index + state_dict["scaler.dummy_buffer"] = torch.tensor( + [0.0], dtype=old_scales_tensor.dtype + ) + state_dict["scaler.model.type_to_index"] = state_dict[ + "additive_models.0.model.type_to_index" + ] + for target_name, target_info in checkpoint["model_data"][ + "dataset_info" + ].targets.items(): + layout = target_info.layout + if layout.sample_names == ["system"]: + samples = Labels(["atomic_type"], torch.tensor([[-1]])) + + elif layout.sample_names == ["system", "atom"]: + samples = Labels( + ["atomic_type"], + torch.arange( + len(checkpoint["model_data"]["dataset_info"].atomic_types) + ).reshape(-1, 1), + ) + else: + raise ValueError( # will never happen + "Unknown sample kind. Please contact the developers." + ) + scales_tensormap = TensorMap( + keys=layout.keys, + blocks=[ + TensorBlock( + values=torch.full( # important when scale_targets=False + (len(samples), len(block.properties)), + old_scales_tensor[ + old_output_name_to_output_index[target_name] + ], + dtype=torch.float64, + ), + samples=samples, + components=[], + properties=block.properties, + ) + for block in layout.blocks() + ], + ) + state_dict[f"scaler.{target_name}_scaler_buffer"] = mts.save_buffer( + mts.make_contiguous(scales_tensormap) + ) + + +def model_update_v7_v8(checkpoint: dict) -> None: + """ + Update a v7 checkpoint to v8. + + :param checkpoint: The checkpoint to update. + """ + if "node_embedding.weight" in checkpoint["model_state_dict"]: + ############################################## + # **Updating the large-scale PET checkpoints** + ############################################## + + checkpoint["model_data"]["model_hypers"]["normalization"] = "RMSNorm" + checkpoint["model_data"]["model_hypers"]["activation"] = "SwiGLU" + checkpoint["model_data"]["model_hypers"]["d_node"] = ( + 4 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PreLN" + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "feedforward" + checkpoint["model_data"]["model_hypers"]["d_feedforward"] = ( + 2 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + if (state_dict := checkpoint.get("model_state_dict")) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + if "embedding." 
in k and "node" not in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if "node_embedding." in k: + k = k.replace("node_embedding.", "node_embedders.0.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + if "combination_rmsnorms" in k: + k = k.replace("combination_rmsnorms", "combination_norms") + new_state_dict[k] = v + checkpoint["model_state_dict"] = new_state_dict + + # for the large-scale checkpoints, the model for evaluation is always + # taken to be the last + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + else: + ############################################### + # **Updating the standard old PET checkpoints** + ############################################### + + # Adding the option for choosing the normalization type + if "normalization" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["normalization"] = "LayerNorm" + # Adding the option for choosing the activation function + if "activation" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["activation"] = "SiLU" + # Setting the node features dimension to be the same as d_pet if not specified + if "d_node" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["d_node"] = checkpoint[ + "model_data" + ]["model_hypers"]["d_pet"] + # Setting the default transformer type to PostLN if not specified + if "transformer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PostLN" + if "featurizer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "residual" + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + # Replacing the nn.Sequential MLP with a custom FeedForward module + if ".mlp.0" in k: + k = k.replace(".mlp.0", ".mlp.w_in") + if ".mlp.3" in k: + k = k.replace(".mlp.3", ".mlp.w_out") + # Moving the node embedder to a top-level PET model attribute + if "embedding." in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".node_embedder." in k: + key_content = k.split(".") + k = ".".join( + ["node_embedders", key_content[1], key_content[-1]] + ) + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + new_state_dict[k] = v + checkpoint[key] = new_state_dict + + +def model_update_v8_v9(checkpoint: dict) -> None: + """ + Update a v8 checkpoint to v9. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "finetune_config" in state_dict: + if "inherit_heads" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["inherit_heads"] = {} + if "method" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["method"] = "full" + + +def model_update_v9_v10(checkpoint: dict) -> None: + """ + Update a v9 checkpoint to v10. + + :param checkpoint: The checkpoint to update. 
+ """ + # Adding the num_neighbors_adaptive hyperparameter if not present + if "num_neighbors_adaptive" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["num_neighbors_adaptive"] = None + if "cutoff_function" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["cutoff_function"] = "Cosine" + + +def model_update_v10_v11(checkpoint: dict) -> None: + """ + Update a v10 checkpoint to v11. + + :param checkpoint: The checkpoint to update. + """ + # Adding the attention_temperature hyperparameter if not present + if "attention_temperature" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["attention_temperature"] = 1.0 + + +########################### +# TRAINER ################# +########################### + + +def trainer_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 Trainer checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["scheduler_factor"] = checkpoint["train_hypers"].get( + "scheduler_factor", 0.5 + ) + + +def trainer_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 Trainer checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + +def trainer_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 Trainer checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + new_loss_hypers[target_name] = { + "type": old_loss_hypers["type"], + "weight": old_loss_hypers["weights"].get(target_name, 1.0), + "reduction": old_loss_hypers["reduction"], + "sliding_factor": old_loss_hypers.get("sliding_factor", None), + } + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 Trainer checkpoint to v5. + + :param checkpoint: The checkpoint to update. + """ + raise ValueError( + "In order to use this checkpoint, you need metatrain 2025.10 or earlier. " + "You can install it with `pip install metatrain==2025.10`." + ) + + +def trainer_update_v5_v6(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 5 to version 6. + + :param checkpoint: The checkpoint to update. + """ + # num_workers=0 means that the main process will do the data loading, which is + # equivalent to not setting it (this was the behavior before v6) + checkpoint["train_hypers"]["num_workers"] = 0 + + +def trainer_update_v6_v7(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["fixed_scaling_weights"] = {} + + +def trainer_update_v7_v8(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. 
+ """ + # remove all entries in the loss `sliding_factor` + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + # retain everything except sliding_factor for each target + new_loss_hypers[target_name] = { + k: v + for k, v in old_loss_hypers[target_name].items() + if k != "sliding_factor" + } + + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v8_v9(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. + """ + # Adding the empty finetune config if not present + if "finetune" not in checkpoint["train_hypers"]: + checkpoint["train_hypers"]["finetune"] = { + "read_from": None, + "method": "full", + "config": {}, + "inherit_heads": {}, + } + + +def trainer_update_v9_v10(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 9 to version 10. + + :param checkpoint: The checkpoint to update. + """ + # Ensuring that the finetune read_from is None if not specified + checkpoint["train_hypers"]["remove_composition_contribution"] = True + + +def trainer_update_v10_v11(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 10 to version 11. + + :param checkpoint: The checkpoint to update. + """ + # - Remove the ``remove_composition_contribution`` hyper. + # - Rename ``fixed_composition_weights`` to ``atomic_baseline``. + # - If ``remove_composition_contribution`` is False, set all atomic baselines + # to 0.0 for all targets. + use_atomic_baseline = checkpoint["train_hypers"].pop( + "remove_composition_contribution" + ) + atomic_baseline = checkpoint["train_hypers"].pop("fixed_composition_weights") + + if not use_atomic_baseline: + # Just set + dataset_info = checkpoint["model_data"]["dataset_info"] + atomic_baseline = {target_name: 0.0 for target_name in dataset_info.targets} + + checkpoint["train_hypers"]["atomic_baseline"] = atomic_baseline + + +def trainer_update_v11_v12(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 11 to version 12. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["batch_atom_bounds"] = [None, None] diff --git a/src/metatrain/pet_2/documentation.py b/src/metatrain/pet_2/documentation.py new file mode 100644 index 0000000000..612b03d8a1 --- /dev/null +++ b/src/metatrain/pet_2/documentation.py @@ -0,0 +1,251 @@ +""" +PET +=== + +PET is a cleaner, more user-friendly reimplementation of the original +PET model :footcite:p:`pozdnyakov_smooth_2023`. It is designed for better +modularity and maintainability, while preseving compatibility with the original +PET implementation in ``metatrain``. It also adds new features like long-range +features, better fine-tuning implementation, a possibility to train on +arbitrarty targets, and a faster inference due to the ``fast attention``. + +{{SECTION_INSTALLATION}} + +{{SECTION_DEFAULT_HYPERS}} + +Tuning hyperparameters +---------------------- + +The default hyperparameters above will work well in most cases, but they +may not be optimal for your specific dataset. There is good number of +parameters to tune, both for the +:ref:`model ` and the +:ref:`trainer `. Since seeing them +for the first time might be overwhelming, here we provide a **list of the +parameters that are in general the most important** (in decreasing order +of importance): + +.. container:: mtt-hypers-remove-classname + + .. 
autoattribute:: {{model_hypers_path}}.cutoff + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_neighbors_adaptive + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.learning_rate + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.batch_size + :no-index: + + .. autoattribute:: {{model_hypers_path}}.d_pet + :no-index: + + .. autoattribute:: {{model_hypers_path}}.d_node + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_gnn_layers + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_attention_layers + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.loss + :no-index: + + .. autoattribute:: {{model_hypers_path}}.long_range + :no-index: +""" + +from typing import Literal, Optional + +from typing_extensions import TypedDict + +from metatrain.pet_2.modules.finetuning import FinetuneHypers, NoFinetuneHypers +from metatrain.utils.additive import FixedCompositionWeights +from metatrain.utils.hypers import init_with_defaults +from metatrain.utils.long_range import LongRangeHypers +from metatrain.utils.loss import LossSpecification +from metatrain.utils.scaler import FixedScalerWeights + + +class ModelHypers(TypedDict): + """Hyperparameters for the PET model.""" + + cutoff: float = 4.5 + """Cutoff radius for neighbor search. + + This should be set to a value after which most of the interactions + between atoms are expected to be negligible. A lower cutoff will lead + to faster models. + """ + num_neighbors_adaptive: Optional[int] = None + """Target number of neighbors for the adaptive cutoff scheme. + + This parameter activates the adaptive cutoff functionality. + Each atomic environment has a different cutoff, chosen + such that the number of neighbors is approximately equal to this + value. This can be useful to have a more uniform number of neighbors + per atom, especially in sparse systems. Setting it to None disables + this feature and uses all neighbors within the fixed cutoff radius. + """ + cutoff_function: Literal["Cosine", "Bump"] = "Bump" + """Type of the smoothing function at the cutoff""" + cutoff_width: float = 0.5 + """Width of the smoothing function at the cutoff""" + d_pet: int = 128 + """Dimension of the edge features. + + This hyperparameter controls the width of the neural network. In general, + increasing it might lead to better accuracy, especially on larger datasets, at the + cost of increased training and evaluation time. + """ + d_head: int = 128 + """Dimension of the attention heads.""" + d_node: int = 256 + """Dimension of the node features. + + Increasing this hyperparameter might lead to better accuracy, + with a relatively small increase in inference time. + """ + d_feedforward: int = 256 + """Dimension of the feedforward network in the attention layer.""" + num_heads: int = 8 + """Attention heads per attention layer.""" + num_attention_layers: int = 2 + """The number of attention layers in each layer of the graph + neural network. Depending on the dataset, increasing this hyperparameter might + lead to better accuracy, at the cost of increased training and evaluation time. + """ + num_gnn_layers: int = 2 + """The number of graph neural network layers. + + In general, decreasing this hyperparameter to 1 will lead to much faster models, + at the expense of accuracy. Increasing it may or may not lead to better accuracy, + depending on the dataset, at the cost of increased training and evaluation time. 
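The `cutoff`, `cutoff_width` and `cutoff_function` hyperparameters above control how interactions are smoothly switched off at the cutoff radius. As an illustration only (the exact functional forms of the ``Cosine`` and ``Bump`` options live in the PET modules and may differ), a typical cosine-style smoothing looks like:

import torch

def cosine_cutoff(r: torch.Tensor, cutoff: float, width: float) -> torch.Tensor:
    # 1 below (cutoff - width), 0 beyond cutoff, smoothly interpolated in between.
    x = torch.clamp((r - (cutoff - width)) / width, 0.0, 1.0)
    return 0.5 * (1.0 + torch.cos(torch.pi * x))

print(cosine_cutoff(torch.tensor([3.0, 4.2, 4.6]), cutoff=4.5, width=0.5))
# tensor([1.0000, 0.6545, 0.0000])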
+ """ + normalization: Literal["RMSNorm", "LayerNorm"] = "RMSNorm" + """Layer normalization type.""" + activation: Literal["SiLU", "SwiGLU"] = "SwiGLU" + """Activation function.""" + attention_temperature: float = 1.0 + """The temperature scaling factor for attention scores.""" + transformer_type: Literal["PreLN", "PostLN"] = "PreLN" + """The order in which the layer normalization and attention + are applied in a transformer block. Available options are ``PreLN`` + (normalization before attention) and ``PostLN`` (normalization after attention).""" + featurizer_type: Literal["residual", "feedforward"] = "feedforward" + """Implementation of the featurizer of the model to use. Available + options are ``residual`` (the original featurizer from the PET paper, that uses + residual connections at each GNN layer for readout) and ``feedforward`` (a modern + version that uses the last representation after all GNN iterations for readout). + Additionally, the feedforward version uses bidirectional features flow during the + message passing iterations, that favors features flowing from atom ``i`` to atom + ``j`` to be not equal to the features flowing from atom ``j`` to atom ``i``.""" + zbl: bool = False + """Use ZBL potential for short-range repulsion""" + long_range: LongRangeHypers = init_with_defaults(LongRangeHypers) + """Long-range Coulomb interactions parameters.""" + + +class TrainerHypers(TypedDict): + """Hyperparameters for training PET models.""" + + distributed: bool = False + """Whether to use distributed training""" + distributed_port: int = 39591 + """Port for distributed communication among processes""" + batch_size: int = 16 + """The number of samples to use in each batch of training. This + hyperparameter controls the tradeoff between training speed and memory usage. In + general, larger batch sizes will lead to faster training, but might require more + memory.""" + num_epochs: int = 1000 + """Number of epochs.""" + warmup_fraction: float = 0.01 + """Fraction of training steps used for learning rate warmup.""" + learning_rate: float = 1e-4 + """Learning rate.""" + weight_decay: Optional[float] = None + + log_interval: int = 1 + """Interval to log metrics.""" + checkpoint_interval: int = 100 + """Interval to save checkpoints.""" + atomic_baseline: FixedCompositionWeights = {} + """The baselines for each target. + + By default, ``metatrain`` will fit a linear model (:class:`CompositionModel + `) to compute the + least squares baseline for each atomic species for each target. + + However, this hyperparameter allows you to provide your own baselines. + The value of the hyperparameter should be a dictionary where the keys are the + target names, and the values are either (1) a single baseline to be used for + all atomic types, or (2) a dictionary mapping atomic types to their baselines. + For example: + + - ``atomic_baseline: {"energy": {1: -0.5, 6: -10.0}}`` will fix the energy + baseline for hydrogen (Z=1) to -0.5 and for carbon (Z=6) to -10.0, while + fitting the baselines for the energy of all other atomic types, as well + as fitting the baselines for all other targets. + - ``atomic_baseline: {"energy": -5.0}`` will fix the energy baseline for + all atomic types to -5.0. + - ``atomic_baseline: {"mtt:dos": 0.0}`` sets the baseline for the "mtt:dos" + target to 0.0, effectively disabling the atomic baseline for that target. 
+ + This atomic baseline is substracted from the targets during training, which + avoids the main model needing to learn atomic contributions, and likely makes + training easier. When the model is used in evaluation mode, the atomic baseline + is added on top of the model predictions automatically. + + .. note:: + + This atomic baseline is a per-atom contribution. Therefore, if the property + you are predicting is a sum over all atoms (e.g., total energy), the + contribution of the atomic baseline to the total property will be the + atomic baseline multiplied by the number of atoms of that type in the + structure. + """ + scale_targets: bool = True + """Normalize targets to unit std during training.""" + fixed_scaling_weights: FixedScalerWeights = {} + """Weights for target scaling. + + This is passed to the ``fixed_weights`` argument of + :meth:`Scaler.train_model `, + see its documentation to understand exactly what to pass here. + """ + per_structure_targets: list[str] = [] + """Targets to calculate per-structure losses.""" + num_workers: Optional[int] = None + """Number of workers for data loading. If not provided, it is set + automatically.""" + log_mae: bool = True + """Log MAE alongside RMSE""" + log_separate_blocks: bool = False + """Log per-block error.""" + best_model_metric: Literal["rmse_prod", "mae_prod", "loss"] = "mae_prod" + """Metric used to select best checkpoint (e.g., ``rmse_prod``)""" + grad_clip_norm: float = 1.0 + """Maximum gradient norm value.""" + loss: str | dict[str, LossSpecification | str] = "mse" + """This section describes the loss function to be used. See the + :ref:`loss-functions` for more details.""" + batch_atom_bounds: list[Optional[int]] = [None, None] + """Bounds for the number of atoms per batch as [min, max]. Batches with atom + counts outside these bounds will be skipped during training. Use ``None`` for + either value to disable that bound. This is useful for preventing out-of-memory + errors and ensuring consistent computational load. Default: ``[None, None]``.""" + + finetune: NoFinetuneHypers | FinetuneHypers = { + "read_from": None, + "method": "full", + "config": {}, + "inherit_heads": {}, + } + """Parameters for fine-tuning trained PET models. + + See :ref:`label_fine_tuning_concept` for more details. + """ diff --git a/src/metatrain/pet_2/model.py b/src/metatrain/pet_2/model.py new file mode 100644 index 0000000000..38d3a07885 --- /dev/null +++ b/src/metatrain/pet_2/model.py @@ -0,0 +1,1646 @@ +import logging +import typing +import warnings +from math import prod +from typing import Any, Dict, List, Literal, Optional, Tuple + +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.operations._add import _add_block_block +from metatomic.torch import ( + AtomisticModel, + ModelCapabilities, + ModelMetadata, + ModelOutput, + NeighborListOptions, + System, +) +from torch.utils.hooks import RemovableHandle + +from metatrain.utils.abc import ModelInterface +from metatrain.utils.additive import ZBL, CompositionModel +from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.dtype import dtype_to_str +from metatrain.utils.long_range import DummyLongRangeFeaturizer, LongRangeFeaturizer +from metatrain.utils.metadata import merge_metadata +from metatrain.utils.scaler import Scaler +from metatrain.utils.sum_over_atoms import sum_over_atoms + +from . 
import checkpoints +from .documentation import ModelHypers +from .modules.finetuning import apply_finetuning_strategy +from .modules.structures import get_pair_sample_labels, systems_to_batch +from .modules.transformer import CartesianTransformer + + +AVAILABLE_FEATURIZERS = typing.get_args(ModelHypers.__annotations__["featurizer_type"]) + + +class PET(ModelInterface[ModelHypers]): + """ + Metatrain-native implementation of the PET architecture. + + Originally proposed in work (https://arxiv.org/abs/2305.19302v3), + and published in the `pet` package (https://github.com/spozdn/pet). + + :param hypers: Hyperparameters for the PET model. See the documentation for details. + :param dataset_info: Information about the dataset, including atomic types and + targets. + """ + + __checkpoint_version__ = 11 + __supported_devices__ = ["cuda", "cpu"] + __supported_dtypes__ = [torch.float32, torch.float64] + __default_metadata__ = ModelMetadata( + references={"architecture": ["https://arxiv.org/abs/2305.19302v3"]} + ) + component_labels: Dict[str, List[List[Labels]]] + NUM_FEATURE_TYPES: int = 2 # node + edge features + + def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None: + super().__init__(hypers, dataset_info, self.__default_metadata__) + + # Cache frequently accessed hyperparameters + self.cutoff = float(self.hypers["cutoff"]) + self.cutoff_function = self.hypers["cutoff_function"] + self.cutoff_width = float(self.hypers["cutoff_width"]) + self.num_neighbors_adaptive = ( + float(self.hypers["num_neighbors_adaptive"]) + if self.hypers["num_neighbors_adaptive"] is not None + else None + ) + self.d_pet = self.hypers["d_pet"] + self.d_node = self.hypers["d_node"] + self.d_head = self.hypers["d_head"] + self.d_feedforward = self.hypers["d_feedforward"] + self.num_heads = self.hypers["num_heads"] + self.num_gnn_layers = self.hypers["num_gnn_layers"] + self.num_attention_layers = self.hypers["num_attention_layers"] + self.normalization = self.hypers["normalization"] + self.activation = self.hypers["activation"] + self.attention_temperature = self.hypers["attention_temperature"] + self.transformer_type = self.hypers["transformer_type"] + self.featurizer_type = self.hypers["featurizer_type"] + + self.atomic_types = dataset_info.atomic_types + self.requested_nl = NeighborListOptions( + cutoff=self.cutoff, + full_list=True, + strict=True, + ) + num_atomic_species = len(self.atomic_types) + self.gnn_layers = torch.nn.ModuleList( + [ + CartesianTransformer( + self.cutoff, + self.cutoff_width, + self.d_pet, + self.num_heads, + self.d_node, + self.d_feedforward, + self.num_attention_layers, + self.normalization, + self.activation, + self.attention_temperature, + self.transformer_type, + num_atomic_species, + layer_index == 0, # is first layer + ) + for layer_index in range(self.num_gnn_layers) + ] + ) + if self.featurizer_type not in AVAILABLE_FEATURIZERS: + raise ValueError( + f"Unknown featurizer type: {self.featurizer_type}. 
" + f"Available options are: {AVAILABLE_FEATURIZERS}" + ) + if self.featurizer_type == "feedforward": + self.num_readout_layers = 1 + self.combination_norms = torch.nn.ModuleList( + [torch.nn.LayerNorm(2 * self.d_pet) for _ in range(self.num_gnn_layers)] + ) + self.combination_mlps = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(2 * self.d_pet, 2 * self.d_pet), + torch.nn.SiLU(), + torch.nn.Linear(2 * self.d_pet, self.d_pet), + ) + for _ in range(self.num_gnn_layers) + ] + ) + else: + self.num_readout_layers = self.num_gnn_layers + self.combination_norms = torch.nn.ModuleList() + self.combination_mlps = torch.nn.ModuleList() + + self.node_embedders = torch.nn.ModuleList( + [ + torch.nn.Embedding(num_atomic_species, self.d_node) + for _ in range(self.num_readout_layers) + ] + ) + self.edge_embedder = torch.nn.Embedding(num_atomic_species, self.d_pet) + + self.node_heads = torch.nn.ModuleDict() + self.edge_heads = torch.nn.ModuleDict() + self.node_last_layers = torch.nn.ModuleDict() + self.edge_last_layers = torch.nn.ModuleDict() + self.last_layer_feature_size = ( + self.num_readout_layers * self.d_head * self.NUM_FEATURE_TYPES + ) # for LLPR + + # the model is always capable of outputting the internal features + self.outputs = { + "features": ModelOutput(per_atom=True, description="internal features") + } + + self.output_shapes: Dict[str, Dict[str, List[int]]] = {} + self.key_labels: Dict[str, Labels] = {} + self.property_labels: Dict[str, List[Labels]] = {} + self.component_labels: Dict[str, List[List[Labels]]] = {} + self.target_names: List[str] = [] + self.last_layer_parameter_names: Dict[str, List[str]] = {} # for LLPR + for target_name, target_info in dataset_info.targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target_info) + + self.register_buffer( + "species_to_species_index", + torch.full((max(self.atomic_types) + 1,), -1), + ) + for i, species in enumerate(self.atomic_types): + self.species_to_species_index[species] = i + + # long-range module + if self.hypers["long_range"]["enable"]: + self.long_range = True + if not self.hypers["long_range"]["use_ewald"]: + warnings.warn( + "Training PET with the LongRangeFeaturizer initialized " + "with `use_ewald=False` causes instabilities during training. " + "The `use_ewald` variable will be force-switched to `True`. 
" + "during training.", + UserWarning, + stacklevel=2, + ) + self.long_range_featurizer = LongRangeFeaturizer( + hypers=self.hypers["long_range"], + feature_dim=self.d_node, + neighbor_list_options=self.requested_nl, + ) + else: + self.long_range = False + self.long_range_featurizer = DummyLongRangeFeaturizer() # for torchscript + + # additive models: these are handled by the trainer at training + # time, and they are added to the output at evaluation time + composition_model = CompositionModel( + hypers={}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + additive_models = [composition_model] + + # Adds the ZBL repulsion model if requested + if self.hypers["zbl"]: + additive_models.append( + ZBL( + {}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if ZBL.is_valid_target(target_name, target_info) + }, + ), + ) + ) + self.additive_models = torch.nn.ModuleList(additive_models) + + # scaler: this is also handled by the trainer at training time + self.scaler = Scaler(hypers={}, dataset_info=dataset_info) + + self.single_label = Labels.single() + + self.finetune_config: Dict[str, Any] = {} + + def supported_outputs(self) -> Dict[str, ModelOutput]: + return self.outputs + + def restart(self, dataset_info: DatasetInfo) -> "PET": + # merge old and new dataset info + merged_info = self.dataset_info.union(dataset_info) + new_atomic_types = [ + at for at in merged_info.atomic_types if at not in self.atomic_types + ] + new_targets = { + key: value + for key, value in merged_info.targets.items() + if key not in self.dataset_info.targets + } + self.has_new_targets = len(new_targets) > 0 + + if len(new_atomic_types) > 0: + raise ValueError( + f"New atomic types found in the dataset: {new_atomic_types}. " + "The PET model does not support adding new atomic types." + ) + + # register new outputs as new last layers + for target_name, target in new_targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target) + + self.dataset_info = merged_info + + # restart the composition and scaler models + self.additive_models[0] = self.additive_models[0].restart( + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + self.scaler = self.scaler.restart(dataset_info) + + return self + + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return [self.requested_nl] + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + """ + Forward pass of the PET model. + + The forward pass processes atomic systems through multiple stages to produce + predictions for the requested outputs. The computation follows a graph neural + network architecture with attention-based message passing. 
+ + **Stage 0: Input Preparation** + + The input systems are first converted into a batched representation containing: + + - `element_indices_nodes` [n_atoms]: Atomic species of the central atoms + - `element_indices_neighbors` [n_atoms, max_num_neighbors]: Atomic species of + neighboring atoms + - `edge_vectors` [n_atoms, max_num_neighbors, 3]: Cartesian edge vectors + between central atoms and their neighbors + - `padding_mask` [n_atoms, max_num_neighbors]: Mask indicating real vs padded + neighbors + - `reverse_neighbor_index` [n_atoms * max_num_neighbors]: Index of the ji edge + for each ij edge, once the edges are flattened into an array whose first + dimension is n_atoms * max_num_neighbors + - `system_indices` [n_atoms]: System index for each central atom + - `sample_labels` [n_atoms, 2]: Metatensor Labels containing indices of each + atom in each system + + **Stage 1: Feature Computation via GNN Layers** + + Node and edge representations are computed by iterating through the GNN layers + following one of two featurization strategies: + + - **Feedforward featurization**: Features are propagated through all + `num_gnn_layers` GNN layers sequentially, using only the final layer outputs + for readout. At each layer, forward and reversed edge messages are combined + using combination MLPs to enable bidirectional information flow. + + - **Residual featurization**: Intermediate node and edge features from each + GNN layer are saved and used during readout. Edge messages between layers are + averaged to maintain information from all hops. + + During this stage, the model: + + - Embeds atomic species into learned node and edge representations + - Applies Cartesian transformer layers to update features via attention + - Uses reversed neighbor lists to enable bidirectional message passing, where + the new input message from atom `j` to atom `i` in GNN layer N+1 is the + reversed message from atom `i` to atom `j` in GNN layer N + - Applies cutoff functions to weight interactions by distance + + If the long-range module is enabled, electrostatic features computed via Ewald + summation (during training) or Particle-Particle Particle Mesh Ewald (P3M) + (during evaluation) are added to the node features from each GNN layer. + + **Stage 2: Intermediate Feature Output (Optional)** + + If "features" is requested in the outputs, node and edge features from all + layers are concatenated to produce intermediate representations. Edge features + are summed over neighbors with cutoff weighting to obtain per-node + contributions. This output can be used for transfer learning or analysis. + + **Stage 3: Last Layer Feature Computation** + + For each requested output, output-specific heads (shallow MLPs with two linear + layers and SiLU activations) are applied to both node and edge features from + each GNN layer. This produces last layer features that are specialized for each + prediction target. These features can be optionally returned as auxiliary + outputs (e.g., "mtt::aux::energy_last_layer_features") for analysis or + transfer learning. 
+ + **Stage 4: Atomic Predictions** + + Final linear layers are applied to the last layer features to produce per-atom + predictions for each requested output: + + - Node and edge last layer features are processed through separate linear + layers for each output block + - Contributions from all GNN layers are summed + - Edge contributions are summed over neighbors with cutoff weighting + - For rank-2 Cartesian tensors (e.g., stress), predictions are symmetrized and + normalized by cell volume + - Multiple tensor blocks per output are handled independently + + **Post-processing (Evaluation Only)** + + During evaluation (not training), the following transformations are applied: + + 1. **Scaling**: Predictions are scaled using learned or configured scale + factors + 2. **Additive contributions**: Composition model and optional ZBL repulsion + contributions are added to the predictions + + :param systems: List of `metatomic.torch.System` objects to process. Each + system should contain atomic positions, species, and cell information, with + neighbor lists computed according to `requested_neighbor_lists()`. + :param outputs: Dictionary of requested outputs in the format + {output_name: ModelOutput(...)}. The model supports: + + - Target properties (energy, forces, stress, etc.) + - "features": intermediate representations from Stage 2 + - Auxiliary last layer features (e.g., + "mtt::aux::energy_last_layer_features") + + :param selected_atoms: Optional `metatensor.torch.Labels` object specifying a + subset of atoms for which to compute outputs. If `None`, all atoms are + included. This is useful for computing properties for specific atomic + environments. + :return: Dictionary of `metatensor.torch.TensorMap` objects containing the + requested outputs. Each TensorMap contains per-atom or per-structure + predictions (depending on the ModelOutput configuration) with appropriate + metatensor metadata (samples, components, properties). 
+ """ + device = systems[0].device + return_dict: Dict[str, TensorMap] = {} + nl_options = self.requested_neighbor_lists()[0] + + if self.single_label.values.device != device: + self._move_labels_to_device(device) + + # **Stage 0: Input Preparation** + ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + padding_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + sample_labels, + centers, + nef_to_edges_neighbor, + ) = systems_to_batch( + systems, + nl_options, + self.atomic_types, + self.species_to_species_index, + self.cutoff_function, + self.cutoff_width, + self.num_neighbors_adaptive, + selected_atoms, + ) + + pair_sample_labels = get_pair_sample_labels( + systems, sample_labels, nl_options, device + ) + + # Optional diagnostic token capture: register temporary module hooks + diagnostic_handles = torch.jit.annotate(List[Any], []) + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + diagnostic_handles = self._prepare_diagnostic_handles( + outputs, + return_dict, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + + # the scaled_dot_product_attention function from torch cannot do + # double backward, so we will use manual attention if needed + use_manual_attention = edge_vectors.requires_grad and self.training + + with torch.profiler.record_function("PET::_calculate_features"): + # **Stage 1: Feature Computation via GNN Layers** + featurizer_inputs: Dict[str, torch.Tensor] = dict( + element_indices_nodes=element_indices_nodes, + element_indices_neighbors=element_indices_neighbors, + edge_vectors=edge_vectors, + edge_distances=edge_distances, + reverse_neighbor_index=reverse_neighbor_index, + padding_mask=padding_mask, + cutoff_factors=cutoff_factors, + ) + for featurizer_input_name, tensor in featurizer_inputs.items(): + if "mtt::features::" + featurizer_input_name in outputs: + return_dict["mtt::features::" + featurizer_input_name] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + node_features_list, edge_features_list = self._calculate_features( + featurizer_inputs, + use_manual_attention=use_manual_attention, + ) + + # If the long-range module is activated, we add the long-range features + # on top of the node features + + if self.long_range: + long_range_features = self._calculate_long_range_features( + systems, node_features_list, edge_distances, padding_mask + ) + for i in range(self.num_readout_layers): + node_features_list[i] = ( + node_features_list[i] + long_range_features + ) * 0.5**0.5 + + # **Stage 2: Intermediate Feature Output (Optional)** + with torch.profiler.record_function("PET::_get_output_features"): + if "features" in outputs: + features_dict = self._get_output_features( + node_features_list, + edge_features_list, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + # Since return_dict.update(features_dict) is not Torch-Scriptable, + # we use a simple iteration over the features_dict items. 
+ for k, v in features_dict.items(): + return_dict[k] = v + + # **Stage 3: Last Layer Feature Computation** + with torch.profiler.record_function("PET::_calculate_last_layer_features"): + node_last_layer_features_dict, edge_last_layer_features_dict = ( + self._calculate_last_layer_features( + node_features_list, + edge_features_list, + ) + ) + last_layer_features_dict = self._get_output_last_layer_features( + node_last_layer_features_dict, + edge_last_layer_features_dict, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + + for k, v in last_layer_features_dict.items(): + return_dict[k] = v + + # **Stage 4: Atomic Predictions** + with torch.profiler.record_function("PET::_calculate_atomic_predictions"): + node_atomic_predictions_dict, edge_atomic_predictions_dict = ( + self._calculate_atomic_predictions( + node_last_layer_features_dict, + edge_last_layer_features_dict, + padding_mask, + cutoff_factors, + outputs, + ) + ) + atomic_predictions_dict = self._get_output_atomic_predictions( + systems, + node_atomic_predictions_dict, + edge_atomic_predictions_dict, + edge_vectors, + system_indices, + sample_labels, + outputs, + selected_atoms, + ) + + for k, v in atomic_predictions_dict.items(): + return_dict[k] = v + + # **Post-processing (Evaluation Only)** + + with torch.profiler.record_function("PET::post-processing"): + if not self.training: + # at evaluation, we also introduce the scaler and additive contributions + return_dict = self.scaler( + systems, return_dict, selected_atoms=selected_atoms + ) + for additive_model in self.additive_models: + outputs_for_additive_model: Dict[str, ModelOutput] = {} + for name, output in outputs.items(): + if name in additive_model.outputs: + outputs_for_additive_model[name] = output + additive_contributions = additive_model( + systems, + outputs_for_additive_model, + selected_atoms, + ) + for name in additive_contributions: + # TODO: uncomment this after metatensor.torch.add + # is updated to handle sparse sums + # return_dict[name] = metatensor.torch.add( + # return_dict[name], + # additive_contributions[name].to( + # device=return_dict[name].device, + # dtype=return_dict[name].dtype + # ), + # ) + # TODO: "manual" sparse sum: update to metatensor.torch.add + # after sparse sum is implemented in metatensor.operations + output_blocks: List[TensorBlock] = [] + for k, b in return_dict[name].items(): + if k in additive_contributions[name].keys: + output_blocks.append( + _add_block_block( + b, + additive_contributions[name] + .block(k) + .to(device=b.device, dtype=b.dtype), + ) + ) + else: + output_blocks.append(b) + return_dict[name] = TensorMap( + return_dict[name].keys, output_blocks + ) + + # remove any diagnostic hooks we registered and attach tokens + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + for h in diagnostic_handles: + h.remove() + + return return_dict + + def _create_diagnostic_feature_tensormap( + self, + tensor: torch.Tensor, + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> TensorMap: + assert tensor.shape[0] == sample_labels.values.shape[0], ( + "diagnostic feature tensor must be per-atom or per-pair like in shape." 
+ f" Got tensor.shape = {tensor.shape}, " + ) + + device = tensor.device + + outp = tensor.detach().clone() + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + if outp.shape[1] == 1: # node-like, shape (n_atoms, 1, d) + outp = outp.squeeze(1) + labels = sample_labels + + else: # edge-like, shape (n_atoms, num_neighbors, d) + outp = outp[centers, nef_to_edges_neighbor] + labels = pair_sample_labels + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + return TensorMap( + Labels(["_"], torch.tensor([[0]])).to(device=device), + [ + TensorBlock( + values=outp, + samples=labels.to(device=device), + components=[], + properties=Labels( + ["_"], + torch.arange(outp.shape[1]).reshape(-1, 1), + ).to(device=device), + ) + ], + ) + + @torch.jit.ignore + def _prepare_diagnostic_handles( + self, + outputs: Dict[str, ModelOutput], + return_dict: Dict[str, Any], + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> List[Any]: + """ + Prepare forward hooks to capture diagnostic tokens from internal modules. + + :param outputs: Dictionary of requested outputs. + :param return_dict: Dictionary to store captured tokens. + :param centers: Tensor mapping center atoms to their indices. + :param nef_to_edges_neighbor: Tensor mapping neighbor edges to their indices. + :param sample_labels: Labels for individual atoms. + :param pair_sample_labels: Labels for atom pairs. + + :return: List of removable handles for the registered hooks. + """ + diagnostic_handles = [] + + def _resolve_module(path: str) -> Any: + obj: Any = self + for part in path.split("."): + if part.isdigit(): + obj = obj[int(part)] + else: + if not hasattr(obj, part): + raise AttributeError( + f"Module path '{path}' not found at '{part}'" + ) + obj = getattr(obj, part) + return obj + + # Build list of possible module paths that can be captured for these model + # hypers. 
+ possible_capture_paths: List[str] = [] + + for i in range(self.num_readout_layers): + possible_capture_paths.append(f"node_embedders.{i}") + possible_capture_paths.append("edge_embedder") + + for i in range(self.num_gnn_layers): + # Total GNN layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}_node") + possible_capture_paths.append(f"gnn_layers.{i}_edge") + + # Finer-grained: embeddings and compressions + possible_capture_paths.append(f"gnn_layers.{i}.edge_embedder") + if i > 0: + possible_capture_paths.append(f"gnn_layers.{i}.neighbor_embedder") + possible_capture_paths.append(f"gnn_layers.{i}.compress") + + # transformer layers + for j in range(self.num_attention_layers): + # Transformer layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_node") + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_edge") + + # Finer-grained: attention and MLP submodules + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_contraction" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_mlp" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_mlp" + ) + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}.mlp") + + # TODO: this is subtle, depending on whether the featurizer is feedforward or + # residual. To fix later. + # possible_capture_paths.append(f"combination_norms") + # possible_capture_paths.append(f"combination_mlps") + + for path in possible_capture_paths: + if "mtt::features::" + path not in outputs: + continue + + if "_node" in path: + suffix = "_node" + path = path.replace(suffix, "") + elif "_edge" in path: + suffix = "_edge" + path = path.replace(suffix, "") + else: + suffix = "" + + module = _resolve_module(path) + + def make_hook(p: str, suffix: str) -> Any: + def _hook( + module: torch.nn.Module, inp: Any, outp: Any + ) -> None: + if isinstance(outp, tuple): + assert "_node" in suffix or "_edge" in suffix, ( + "When capturing from a module that returns multiple " + "outputs, the requested output must carry the suffix " + "'_node' or '_edge'." + ) + if suffix == "_node": + tensor = outp[0] + else: + assert suffix == "_edge" + tensor = outp[1] + + return_dict[f"mtt::features::{p}{suffix}"] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + else: + return_dict["mtt::features::" + p] = ( + self._create_diagnostic_feature_tensormap( + outp, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + + return _hook + + handle = module.register_forward_hook(make_hook(path, suffix)) + diagnostic_handles.append(handle) + + return diagnostic_handles + + def _calculate_features( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Calculate node and edge features using the selected featurization strategy. + Returns lists of feature tensors from GNN layers. 
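The diagnostic capture mechanism above lets callers request intermediate activations by module path. A sketch of how one of the listed paths could be requested (`model` and `systems` prepared elsewhere; the path follows the `possible_capture_paths` built above):

from metatomic.torch import ModelOutput

# Node output of the first GNN layer, captured through a temporary forward hook.
outputs = {"mtt::features::gnn_layers.0_node": ModelOutput(per_atom=True)}
# features = model(systems, outputs)["mtt::features::gnn_layers.0_node"]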
+ + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors + - List of edge feature tensors + In the case of feedforward featurization, each list contains a single tensor + from the final GNN layer. In the case of residual featurization, each list + contains tensors from all GNN layers. + """ + if self.featurizer_type == "feedforward": + return self._feedforward_featurization_impl(inputs, use_manual_attention) + else: + return self._residual_featurization_impl(inputs, use_manual_attention) + + def _feedforward_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Feedforward featurization: iterates features through all GNN layers, + returning only the final layer outputs. Uses combination MLPs to mix + forward and reversed edge messages at each layer. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from the final GNN layer + - List of edge feature tensors from the final GNN layer + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + + input_node_embeddings = self.node_embedders[0](inputs["element_indices_nodes"]) + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for combination_norm, combination_mlp, gnn_layer in zip( + self.combination_norms, self.combination_mlps, self.gnn_layers, strict=True + ): + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + input_node_embeddings = output_node_embeddings + new_input_edge_embeddings = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + # input_messages = 0.5 * (output_edge_embeddings + new_input_messages) + concatenated = torch.cat( + [output_edge_embeddings, new_input_edge_embeddings], dim=-1 + ) + input_edge_embeddings = ( + input_edge_embeddings + + output_edge_embeddings + + combination_mlp(combination_norm(concatenated)) + ) + + node_features_list.append(input_node_embeddings) + edge_features_list.append(input_edge_embeddings) + return node_features_list, edge_features_list + + def _residual_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Residual featurization: saves intermediate features from each GNN layer + for use in readout. 
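The reverse-neighbor reordering used by both featurizers (flatten the (n_atoms, max_neighbors, d) messages, index with `reverse_neighbor_index`, reshape back) can be illustrated on a toy tensor with a hypothetical reverse index:

import torch

n_atoms, max_neighbors, d = 2, 2, 3
messages = torch.arange(n_atoms * max_neighbors * d, dtype=torch.float32).reshape(
    n_atoms, max_neighbors, d
)
# Hypothetical index: edge (i, j) points to the flattened position of edge (j, i).
reverse_neighbor_index = torch.tensor([2, 3, 0, 1])
reversed_messages = messages.reshape(n_atoms * max_neighbors, d)[
    reverse_neighbor_index
].reshape(n_atoms, max_neighbors, d)
print(reversed_messages[0, 0])  # tensor([6., 7., 8.]): the message along the reversed edge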
Averages forward and reversed edge messages between layers. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from all GNN layers + - List of edge feature tensors from all GNN layers + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for node_embedder, gnn_layer in zip( + self.node_embedders, self.gnn_layers, strict=True + ): + input_node_embeddings = node_embedder(inputs["element_indices_nodes"]) + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + node_features_list.append(output_node_embeddings) + edge_features_list.append(output_edge_embeddings) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + # (Flatten, index, and reshape to the original shape) + new_input_messages = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + input_edge_embeddings = 0.5 * (input_edge_embeddings + new_input_messages) + return node_features_list, edge_features_list + + def _calculate_long_range_features( + self, + systems: List[System], + node_features_list: List[torch.Tensor], + edge_distances: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + """ + Calculate long-range electrostatic features using Ewald summation. + Forces use_ewald=True during training for stability. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_distances: Tensor of edge distances [n_atoms, max_num_neighbors]. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :return: Tensor of long-range features [n_atoms, d_pet]. + """ + if self.training: + # Currently, the long-range implementation show instabilities + # during training if P3MCalculator is used instead of the + # EwaldCalculator. We will use the EwaldCalculator for training. 
+ self.long_range_featurizer.use_ewald = True + flattened_lengths = edge_distances[padding_mask] + short_range_features = ( + torch.stack(node_features_list).sum(dim=0) + * (1 / len(node_features_list)) ** 0.5 + ) + long_range_features = self.long_range_featurizer( + systems, short_range_features, flattened_lengths + ) + return long_range_features + + def _get_output_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Concatenate node and edge features from all layers into intermediate + feature representations. Edge features are summed with cutoff weighting. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping "features" to a TensorMap of intermediate + representations, either per-atom or summed over atoms. + """ + features_dict: Dict[str, TensorMap] = {} + node_features = torch.cat(node_features_list, dim=1) + edge_features = torch.cat(edge_features_list, dim=2) + edge_features = (edge_features * cutoff_factors[:, :, None]).sum(dim=1) + features = torch.cat([node_features, edge_features], dim=1) + + feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=features, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + features.shape[-1], device=features.device + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + feature_tmap = mts.slice( + feature_tmap, + axis="samples", + selection=selected_atoms, + ) + if requested_outputs["features"].per_atom: + features_dict["features"] = feature_tmap + else: + features_dict["features"] = sum_over_atoms(feature_tmap) + return features_dict + + def _calculate_last_layer_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + ) -> Tuple[Dict[str, List[torch.Tensor]], Dict[str, List[torch.Tensor]]]: + """ + Apply output-specific heads to node and edge features from each GNN layer. + Returns dictionaries mapping output names to lists of head-transformed features. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of node last layer features + - Dictionary mapping output names to lists of edge last layer features + """ + node_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + + # Calculating node last layer features + for output_name, node_heads in self.node_heads.items(): + if output_name not in node_last_layer_features_dict: + node_last_layer_features_dict[output_name] = [] + for i, node_head in enumerate(node_heads): + node_last_layer_features_dict[output_name].append( + node_head(node_features_list[i]) + ) + + # Calculating edge last layer features + for output_name, edge_heads in self.edge_heads.items(): + if output_name not in edge_last_layer_features_dict: + edge_last_layer_features_dict[output_name] = [] + for i, edge_head in enumerate(edge_heads): + edge_last_layer_features_dict[output_name].append( + edge_head(edge_features_list[i]) + ) + + return node_last_layer_features_dict, edge_last_layer_features_dict + + def _get_output_last_layer_features( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge last layer features for requested last layer + features output. Edge features are summed with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping requested last layer features output names + to TensorMaps of last layer features, either per-atom or summed over atoms. 
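# --- Illustrative aside (not part of the patch): the naming convention this
# method relies on, produced by get_last_layer_features_name() defined later
# in this file. The target names are example values only, e.g.
#   get_last_layer_features_name("energy")      == "mtt::aux::energy_last_layer_features"
#   get_last_layer_features_name("mtt::dipole") == "mtt::aux::dipole_last_layer_features"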
+ """ + last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + last_layer_features_outputs: Dict[str, TensorMap] = {} + for output_name in node_last_layer_features_dict.keys(): + if not should_compute_last_layer_features(output_name, requested_outputs): + continue + if output_name not in last_layer_features_dict: + last_layer_features_dict[output_name] = [] + for i in range(len(node_last_layer_features_dict[output_name])): + node_last_layer_features = node_last_layer_features_dict[output_name][i] + edge_last_layer_features = edge_last_layer_features_dict[output_name][i] + edge_last_layer_features = ( + edge_last_layer_features * cutoff_factors[:, :, None] + ).sum(dim=1) + last_layer_features_dict[output_name].append(node_last_layer_features) + last_layer_features_dict[output_name].append(edge_last_layer_features) + + for output_name in requested_outputs: + if not ( + output_name.startswith("mtt::aux::") + and output_name.endswith("_last_layer_features") + ): + continue + base_name = output_name.replace("mtt::aux::", "").replace( + "_last_layer_features", "" + ) + # the corresponding output could be base_name or mtt::base_name + if f"mtt::{base_name}" in last_layer_features_dict: + base_name = f"mtt::{base_name}" + last_layer_features_values = torch.cat( + last_layer_features_dict[base_name], dim=1 + ) + last_layer_feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=last_layer_features_values, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + last_layer_features_values.shape[-1], + device=last_layer_features_values.device, + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + last_layer_feature_tmap = mts.slice( + last_layer_feature_tmap, + axis="samples", + selection=selected_atoms, + ) + last_layer_features_options = requested_outputs[output_name] + if last_layer_features_options.per_atom: + last_layer_features_outputs[output_name] = last_layer_feature_tmap + else: + last_layer_features_outputs[output_name] = sum_over_atoms( + last_layer_feature_tmap + ) + return last_layer_features_outputs + + def _calculate_atomic_predictions( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + padding_mask: torch.Tensor, + cutoff_factors: torch.Tensor, + outputs: Dict[str, ModelOutput], + ) -> Tuple[ + Dict[str, List[List[torch.Tensor]]], Dict[str, List[List[torch.Tensor]]] + ]: + """ + Apply final linear layers to last layer features to produce + per-atom predictions. Handles multiple blocks per output and sums + edge contributions with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param outputs: Dictionary of requested outputs. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of lists of node atomic + prediction tensors (one list per GNN layer, one tensor per block) + - Dictionary mapping output names to lists of lists of edge atomic + prediction tensors (one list per GNN layer, one tensor per block) + """ + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + + # Computing node atomic predictions. Since we have last layer features + # for each GNN layer, and each last layer can have multiple blocks, + # we apply each last layer block to each of the last layer features. + + for output_name, node_last_layers in self.node_last_layers.items(): + if output_name in outputs: + node_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, node_last_layer in enumerate(node_last_layers): + node_last_layer_features = node_last_layer_features_dict[ + output_name + ][i] + node_atomic_predictions_by_block: List[torch.Tensor] = [] + for node_last_layer_by_block in node_last_layer.values(): + node_atomic_predictions_by_block.append( + node_last_layer_by_block(node_last_layer_features) + ) + node_atomic_predictions_dict[output_name].append( + node_atomic_predictions_by_block + ) + + # Computing edge atomic predictions. Following the same logic as above, + # we (1) iterate over the last layer features and last layer blocks, and (2) + # sum the edge features with cutoff factors to get their per-node contribution. + + for output_name, edge_last_layers in self.edge_last_layers.items(): + if output_name in outputs: + edge_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, edge_last_layer in enumerate(edge_last_layers): + edge_last_layer_features = edge_last_layer_features_dict[ + output_name + ][i] + edge_atomic_predictions_by_block: List[torch.Tensor] = [] + for edge_last_layer_by_block in edge_last_layer.values(): + edge_atomic_predictions = edge_last_layer_by_block( + edge_last_layer_features + ) + expanded_padding_mask = padding_mask[..., None].repeat( + 1, 1, edge_atomic_predictions.shape[2] + ) + edge_atomic_predictions = torch.where( + ~expanded_padding_mask, 0.0, edge_atomic_predictions + ) + edge_atomic_predictions_by_block.append( + (edge_atomic_predictions * cutoff_factors[:, :, None]).sum( + dim=1 + ) + ) + edge_atomic_predictions_dict[output_name].append( + edge_atomic_predictions_by_block + ) + + return node_atomic_predictions_dict, edge_atomic_predictions_dict + + def _get_output_atomic_predictions( + self, + systems: List[System], + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_vectors: torch.Tensor, + system_indices: torch.Tensor, + sample_labels: Labels, + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge atomic predictions into final TensorMaps. + Handles rank-2 Cartesian tensors by symmetrizing them. + Returns per-atom or per-structure predictions based on output configuration. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of node atomic prediction tensors (one list per GNN layer, + one tensor per block). 
+ :param edge_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of edge atomic prediction tensors (one list per GNN layer, + one tensor per block). + :param edge_vectors: Tensor of edge vectors [n_atoms, max_num_neighbors, 3]. + :param system_indices: Tensor mapping each atom to its system index + [n_atoms]. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param outputs: Dictionary of requested outputs. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :return: Dictionary mapping requested output names to TensorMaps of + predictions, either per-atom or summed over atoms. + """ + atomic_predictions_tmap_dict: Dict[str, TensorMap] = {} + for output_name in self.target_names: + if output_name in outputs: + atomic_predictions_by_block = { + key: torch.zeros( + 1, dtype=edge_vectors.dtype, device=edge_vectors.device + ) + for key in self.output_shapes[output_name].keys() + } + + node_atomic_predictions_by_block = node_atomic_predictions_dict[ + output_name + ] + edge_atomic_predictions_by_block = edge_atomic_predictions_dict[ + output_name + ] + for i in range(len(node_atomic_predictions_by_block)): + node_atomic_prediction_block = node_atomic_predictions_by_block[i] + edge_atomic_prediction_block = edge_atomic_predictions_by_block[i] + for j, key in enumerate(atomic_predictions_by_block): + node_atomic_predictions = node_atomic_prediction_block[j] + edge_atomic_predictions = edge_atomic_prediction_block[j] + atomic_predictions_by_block[key] = atomic_predictions_by_block[ + key + ] + (node_atomic_predictions + edge_atomic_predictions) + + if output_name == "non_conservative_stress": # TODO: variants + block_key = list(atomic_predictions_by_block.keys())[0] + output_shapes_values = list( + self.output_shapes[output_name].values() + ) + num_properties = output_shapes_values[0][-1] + symmetrized = process_non_conservative_stress( + atomic_predictions_by_block[block_key], + systems, + system_indices, + num_properties, + ) + atomic_predictions_by_block[block_key] = symmetrized + + blocks = [ + TensorBlock( + values=atomic_predictions_by_block[key].reshape([-1] + shape), + samples=sample_labels, + components=components, + properties=properties, + ) + for key, shape, components, properties in zip( + self.output_shapes[output_name].keys(), + self.output_shapes[output_name].values(), + self.component_labels[output_name], + self.property_labels[output_name], + strict=True, + ) + ] + atomic_predictions_tmap_dict[output_name] = TensorMap( + keys=self.key_labels[output_name], + blocks=blocks, + ) + # If selected atoms request is provided, we slice the atomic predictions + # tensor maps to get the predictions for the selected atoms only. + + if selected_atoms is not None: + for output_name, tmap in atomic_predictions_tmap_dict.items(): + atomic_predictions_tmap_dict[output_name] = mts.slice( + tmap, axis="samples", selection=selected_atoms + ) + + # If per-atom predictions are requested, we return the atomic predictions + # tensor maps. Otherwise, we sum the atomic predictions over the atoms + # to get the final per-structure predictions for each requested output. 
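# --- Illustrative aside (not part of the patch): what "summing over atoms"
# amounts to, shown with plain tensors instead of TensorMaps. The values and
# the helper-free reduction are for illustration only; the model itself uses
# sum_over_atoms() on the TensorMaps built above.
import torch

per_atom = torch.tensor([[1.0], [2.0], [3.0]])  # 3 atoms, 1 property
system_indices = torch.tensor([0, 0, 1])        # atoms 0 and 1 belong to structure 0, atom 2 to structure 1
per_structure = torch.zeros(2, 1).index_add_(0, system_indices, per_atom)
# per_structure is now [[3.0], [3.0]]: per-atom contributions summed per structure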
+ + for output_name, atomic_property in atomic_predictions_tmap_dict.items(): + if outputs[output_name].per_atom: + atomic_predictions_tmap_dict[output_name] = atomic_property + else: + atomic_predictions_tmap_dict[output_name] = sum_over_atoms( + atomic_property + ) + + return atomic_predictions_tmap_dict + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + context: Literal["restart", "finetune", "export"], + ) -> "PET": + if context == "restart": + logging.info(f"Using latest model from epoch {checkpoint['epoch']}") + model_state_dict = checkpoint["model_state_dict"] + elif context in {"finetune", "export"}: + logging.info(f"Using best model from epoch {checkpoint['best_epoch']}") + model_state_dict = checkpoint["best_model_state_dict"] + else: + raise ValueError("Unknown context tag for checkpoint loading!") + + # Create the model + model_data = checkpoint["model_data"] + model = cls( + hypers=model_data["model_hypers"], + dataset_info=model_data["dataset_info"], + ) + + finetune_config = model_state_dict.pop("finetune_config", {}) + if finetune_config: + # Apply the finetuning strategy + model = apply_finetuning_strategy(model, finetune_config) + state_dict_iter = iter(model_state_dict.values()) + next(state_dict_iter) # skip the species_to_species_index + dtype = next(state_dict_iter).dtype + model.to(dtype).load_state_dict(model_state_dict) + model.additive_models[0].sync_tensor_maps() + model.scaler.sync_tensor_maps() + + # Loading the metadata from the checkpoint + model.metadata = merge_metadata(model.metadata, checkpoint.get("metadata")) + + return model + + def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: + dtype = next(self.parameters()).dtype + if dtype not in self.__supported_dtypes__: + raise ValueError(f"unsupported dtype {dtype} for PET") + + # Make sure the model is all in the same dtype + # For example, after training, the additive models could still be in + # float64 + self.to(dtype) + + # Additionally, the composition model contains some `TensorMap`s that cannot + # be registered correctly with Pytorch. This function moves them: + self.additive_models[0].weights_to(torch.device("cpu"), torch.float64) + + interaction_ranges = [self.num_gnn_layers * self.cutoff] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + + capabilities = ModelCapabilities( + outputs=self.outputs, + atomic_types=self.atomic_types, + interaction_range=interaction_range, + length_unit=self.dataset_info.length_unit, + supported_devices=self.__supported_devices__, + dtype=dtype_to_str(dtype), + ) + + metadata = merge_metadata(self.metadata, metadata) + + return AtomisticModel(self.eval(), metadata, capabilities) + + def _add_output(self, target_name: str, target_info: TargetInfo) -> None: + """ + Register a new output target by creating corresponding heads and last layers. + Sets up node/edge heads and linear layers for all readout layers. + + :param target_name: Name of the target to add. + :param target_info: TargetInfo object containing details about the target. + """ + # one output shape for each tensor block, grouped by target (i.e. 
tensormap) + self.output_shapes[target_name] = {} + for key, block in target_info.layout.items(): + dict_key = target_name + for n, k in zip(key.names, key.values, strict=True): + dict_key += f"_{n}_{int(k)}" + self.output_shapes[target_name][dict_key] = [ + len(comp.values) for comp in block.components + ] + [len(block.properties.values)] + + self.outputs[target_name] = ModelOutput( + quantity=target_info.quantity, + unit=target_info.unit, + per_atom=True, + description=target_info.description, + ) + + self.node_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_node, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_pet, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.node_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: torch.nn.Linear( + self.d_head, + prod(shape), + bias=True, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: torch.nn.Linear( + self.d_head, + prod(shape), + bias=True, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + # Register last-layer parameters, in the same order as they are returned as + # last-layer features in the model + self.last_layer_parameter_names[target_name] = [] + for layer_index in range(self.num_readout_layers): + for key in self.output_shapes[target_name].keys(): + self.last_layer_parameter_names[target_name].append( + f"node_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + self.last_layer_parameter_names[target_name].append( + f"edge_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + + ll_features_name = get_last_layer_features_name(target_name) + self.outputs[ll_features_name] = ModelOutput( + per_atom=True, description=f"last layer features for {target_name}" + ) + self.key_labels[target_name] = target_info.layout.keys + self.component_labels[target_name] = [ + block.components for block in target_info.layout.blocks() + ] + self.property_labels[target_name] = [ + block.properties for block in target_info.layout.blocks() + ] + + def _move_labels_to_device(self, device: torch.device) -> None: + self.single_label = self.single_label.to(device) + self.key_labels = { + output_name: label.to(device) + for output_name, label in self.key_labels.items() + } + self.component_labels = { + output_name: [ + [labels.to(device) for labels in components_block] + for components_block in components_tmap + ] + for output_name, components_tmap in self.component_labels.items() + } + self.property_labels = { + output_name: [labels.to(device) for labels in properties_tmap] + for output_name, properties_tmap in self.property_labels.items() + } + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["model_ckpt_version"] == v: + update = getattr(checkpoints, f"model_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["model_ckpt_version"] = v + 1 + + if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__: + raise 
RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using model " + f"version {checkpoint['model_ckpt_version']}, while the current model " + f"version is {cls.__checkpoint_version__}." + ) + + return checkpoint + + def get_checkpoint(self) -> Dict: + model_state_dict = self.state_dict() + model_state_dict["finetune_config"] = self.finetune_config + checkpoint = { + "architecture_name": "pet_2", + "model_ckpt_version": self.__checkpoint_version__, + "metadata": self.metadata, + "model_data": { + "model_hypers": self.hypers, + "dataset_info": self.dataset_info, + }, + "epoch": None, + "best_epoch": None, + "model_state_dict": model_state_dict, + "best_model_state_dict": self.state_dict(), + } + return checkpoint + + +def process_non_conservative_stress( + tensor: torch.Tensor, + systems: List[System], + system_indices: torch.Tensor, + num_properties: int, +) -> torch.Tensor: + """ + Symmetrizes and normalizes by the volume rank-2 Cartesian tensors that are meant + to predict the non-conservative stress. + + :param tensor: Tensor of shape [n_atoms, 9 * num_properties]. + :param systems: List of `metatomic.torch.System` objects to process. + :param system_indices: Tensor mapping each atom to its system index [n_atoms]. + :param num_properties: Number of properties in the tensor (e.g., 6 for stress). + :return: Symmetrized tensor of shape [n_atoms, 3, 3, num_properties], divided by the + cell volume. + """ + # Reshape to 3x3 matrix per atom + tensor_as_three_by_three = tensor.reshape(-1, 3, 3, num_properties) + + # Normalize by cell volume + volumes = torch.stack([torch.abs(torch.det(system.cell)) for system in systems]) + # Zero volume can happen due to metatomic's convention of zero cell + # vectors for non-periodic directions. The actual volume is +inf + volumes[volumes == 0.0] = torch.inf + volumes_by_atom = volumes[system_indices].unsqueeze(1).unsqueeze(2).unsqueeze(3) + tensor_as_three_by_three = tensor_as_three_by_three / volumes_by_atom + + # Symmetrize + tensor_as_three_by_three = ( + tensor_as_three_by_three + tensor_as_three_by_three.transpose(1, 2) + ) / 2.0 + + return tensor_as_three_by_three + + +def get_last_layer_features_name(target_name: str) -> str: + """ + Get the auxiliary output name for last layer features of a target. + + :param target_name: Name of the target. + :return: Name of the corresponding last layer features output. + """ + base_name = target_name.replace("mtt::", "") + return f"mtt::aux::{base_name}_last_layer_features" + + +def should_compute_last_layer_features( + output_name: str, requested_outputs: Dict[str, ModelOutput] +) -> bool: + """ + Check if last layer features should be computed for an output. + + :param output_name: Name of the output to check. + :param requested_outputs: Dictionary of requested outputs. + :return: True if last layer features should be computed, False otherwise. + """ + if output_name in requested_outputs: + return True + ll_features_name = get_last_layer_features_name( + output_name.replace("mtt::aux::", "") + ) + return ll_features_name in requested_outputs diff --git a/src/metatrain/pet_2/modules/adaptive_cutoff.py b/src/metatrain/pet_2/modules/adaptive_cutoff.py new file mode 100644 index 0000000000..5f6206bdb1 --- /dev/null +++ b/src/metatrain/pet_2/modules/adaptive_cutoff.py @@ -0,0 +1,166 @@ +from typing import Optional + +import torch + +from .utilities import cutoff_func_bump as cutoff_func + + +# minimum value for the probe cutoff. this avoids getting too close +# to the central atom. 
in practice it could be also set to a larger value +DEFAULT_MIN_PROBE_CUTOFF = 0.5 +# recommended smooth cutoff width for effective neighbor number calculation +# smaller values lead to a more "step-like" behavior, but can be +# numerically unstable. in practice this will be called with the +# same cutoff as the main cutoff function +DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH = 1.0 + + +def get_adaptive_cutoffs( + centers: torch.Tensor, + edge_distances: torch.Tensor, + num_neighbors_adaptive: float, + num_nodes: int, + max_cutoff: float, + min_cutoff: float = DEFAULT_MIN_PROBE_CUTOFF, + cutoff_width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, + probe_spacing: Optional[float] = None, + weight_width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the adaptive cutoff values for each center atom. + + :param centers: Indices of the center atoms. + :param edge_distances: Distances between centers and their neighbors. + :param num_neighbors_adaptive: Target number of neighbors per atom. + :param num_nodes: Total number of center atoms. + :param max_cutoff: Maximum cutoff distance to consider. + :param min_cutoff: Minimum cutoff distance to consider. + :param cutoff_width: Width of the smooth cutoff taper region. + :param probe_spacing: Spacing between probe cutoffs. If None, it will be + automatically determined from the cutoff width. + :param weight_width: Width of the cutoff selection weight function. If None, it + will be automatically determined from the empirical neighbor counts. + :return: Adapted cutoff distances for each center atom. + """ + + # heuristic for the grid spacing of probe cutoffs, based on a + # the smoothness of the cutoff function + if probe_spacing is None: + probe_spacing = cutoff_width / 4.0 + probe_cutoffs = torch.arange( + min_cutoff, + max_cutoff, + probe_spacing, + device=edge_distances.device, + dtype=edge_distances.dtype, + ) + with torch.profiler.record_function("PET::get_effective_num_neighbors"): + effective_num_neighbors = get_effective_num_neighbors( + edge_distances, + probe_cutoffs, + centers, + num_nodes, + width=cutoff_width, + ) + + with torch.profiler.record_function("PET::get_cutoff_weights"): + cutoffs_weights = get_gaussian_cutoff_weights( + effective_num_neighbors, + num_neighbors_adaptive, + width=weight_width, + ) + with torch.profiler.record_function("PET::calculate_adapted_cutoffs"): + adapted_atomic_cutoffs = probe_cutoffs @ cutoffs_weights.T + return adapted_atomic_cutoffs + + +def get_effective_num_neighbors( + edge_distances: torch.Tensor, + probe_cutoffs: torch.Tensor, + centers: torch.Tensor, + num_nodes: int, + width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, +) -> torch.Tensor: + """ + Computes the effective number of neighbors for each probe cutoff. + + :param edge_distances: Distances between centers and their neighbors. + :param probe_cutoffs: Probe cutoff distances. + :param centers: Indices of the center atoms. + :param num_nodes: Total number of center atoms. + :param width: Width of the cutoff function. + :return: Effective number of neighbors for each center atom and probe cutoff. 
+ """ + + weights = cutoff_func( + edge_distances.unsqueeze(0), probe_cutoffs.unsqueeze(1), width + ) + + probe_num_neighbors = torch.zeros( + (len(probe_cutoffs), num_nodes), + dtype=edge_distances.dtype, + device=edge_distances.device, + ) + # accumulate the weights for all probe cutoffs and center atoms at once + probe_num_neighbors.index_add_(1, centers, weights) + probe_num_neighbors = probe_num_neighbors.T + + return probe_num_neighbors + + +def get_gaussian_cutoff_weights( + effective_num_neighbors: torch.Tensor, + num_neighbors_adaptive: float, + width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the weights for each probe cutoff based on + the effective number of neighbors using Gaussian weights + centered at the expected number of neighbors. + + :param effective_num_neighbors: Effective number of neighbors for each center atom + and probe cutoff. + :param num_neighbors_adaptive: Target maximum number of neighbors per atom. + :param width: Width of the Gaussian cutoff selection function. + :return: Weights for each probe cutoff. + """ + diff = effective_num_neighbors - num_neighbors_adaptive + + # adds a "baseline" corresponding to uniformly-distributed atoms + # this has multiple "good" effects: it pushes the cutoff "out" when + # there are few neighbors, and "in" when there are many, and it + # stabilizes the weights with respect to variations in the neighbor + # distribution when there are empty ranges leading to "flat" + # neighbor count distribution + x = torch.linspace( + 0, + 1, + effective_num_neighbors.shape[1], + device=effective_num_neighbors.device, + dtype=effective_num_neighbors.dtype, + ) + baseline = num_neighbors_adaptive * x**3 + + diff = diff + baseline.unsqueeze(0) + if width is None: + # adaptive width from neighbor-count slope along probe axis (last dim) + eps = 1e-12 + if diff.shape[-1] == 1: + # Can't compute gradient from single point; use scaled diff as proxy + width_t = diff.abs() * 0.5 + eps + else: + # Compute numerical gradient: centered differences for interior, + # one-sided differences at boundaries + (width_t,) = torch.gradient(diff, dim=-1) + width_t = width_t.abs().clamp_min(eps) + else: + width_t = torch.ones_like(diff) * width + + logw = -0.5 * (diff / width_t) ** 2 + weights = torch.exp(logw - logw.max()) + + # row-wise normalization of the weights + weights_sum = weights.sum(dim=1, keepdim=True) + weights = weights / weights_sum + + return weights diff --git a/src/metatrain/pet_2/modules/finetuning.py b/src/metatrain/pet_2/modules/finetuning.py new file mode 100644 index 0000000000..27f2086f19 --- /dev/null +++ b/src/metatrain/pet_2/modules/finetuning.py @@ -0,0 +1,264 @@ +# mypy: disable-error-code=misc +# We ignore misc errors in this file because TypedDict +# with default values is not allowed by mypy. 
+from typing import Any, Optional, Tuple + +import torch +import torch.nn as nn +from typing_extensions import Literal, NotRequired, TypedDict + + +class LoRaFinetuneConfig(TypedDict): + """Configuration for LoRA finetuning strategy.""" + + rank: int + """Rank of the LoRA matrices.""" + alpha: float + """Scaling factor for the LoRA matrices.""" + target_modules: NotRequired[list[str]] + + +class HeadsFinetuneConfig(TypedDict): + """Configuration for heads finetuning strategy.""" + + head_modules: list[str] + """List of module name prefixes for the prediction heads to finetune.""" + last_layer_modules: list[str] + """List of module name prefixes for the last layers to finetune.""" + + +class NoFinetuneHypers(TypedDict): + """Hypers that indicate that no finetuning is to be applied.""" + + read_from: None = None + """No finetuning is indicated by setting this argument to None. + + The rest of finetuning hyperparameters are then ignored. + """ + method: NotRequired[Any] + config: NotRequired[Any] + inherit_heads: NotRequired[Any] + + +class FullFinetuneHypers(TypedDict): + """Hyperparameters to use full finetuning of PET models. + + This means all model parameters are trainable. + """ + + method: Literal["full"] = "full" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: NotRequired[Any] + """No configuration needed for full finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class LoRaFinetuneHypers(TypedDict): + """Hyperparameters for LoRA finetuning of PET models. + + Injects LoRA layers and finetunes only them. + """ + + method: Literal["lora"] = "lora" + """Finetuning method to use""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: LoRaFinetuneConfig + """Configuration for LoRA finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class HeadsFinetuneHypers(TypedDict): + """Hyperparameters for heads finetuning of PET models. + + Freezes all model parameters except for the prediction heads + and last layers. + """ + + method: Literal["heads"] = "heads" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: HeadsFinetuneConfig + """Configuration for heads finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +FinetuneHypers = FullFinetuneHypers | LoRaFinetuneHypers | HeadsFinetuneHypers + + +def apply_finetuning_strategy(model: nn.Module, strategy: FinetuneHypers) -> nn.Module: + """ + Apply the specified finetuning strategy to the model. + This function modifies the model in place based on the provided strategy. + + :param model: The model to be finetuned. + :param strategy: A dictionary specifying the finetuning strategy. + The strategy method can be one of the following: + - lora: Inject LoRA layers into the model, or reapply training if already + present. 
+ - heads: Freeze all parameters except for the heads and last layers. + - full: All parameters are trainable. + Additionally, the strategy can include an "inherit_heads" key, + which is a dictionary mapping the new trainable targets to the existing + targets in the model. This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization. + :return: The modified model with the finetuning strategy applied. + """ + + for param in model.parameters(): + param.requires_grad = True + + if strategy["method"] == "full": + # Full finetuning, all parameters are trainable + pass + + elif strategy["method"] == "lora": + lora_config = strategy["config"] + lora_already_applied = any(isinstance(m, LoRALinear) for m in model.modules()) + if not lora_already_applied: + model_device = next(model.parameters()).device + model_dtype = next(model.parameters()).dtype + model = inject_lora_layers( + model, + target_modules=tuple( + lora_config.get("target_modules", ["input_linear", "output_linear"]) + ), + rank=lora_config.get("rank", 4), + alpha=lora_config.get("alpha", 8), + device=model_device, + dtype=model_dtype, + ) + + # Freeze all except LoRA + for name, param in model.named_parameters(): + if "lora_" not in name: + param.requires_grad = False + + elif strategy["method"] == "heads": + heads_config = strategy.get( + "config", + { + "head_modules": ["node_heads", "edge_heads"], + "last_layer_modules": ["node_last_layers", "edge_last_layers"], + }, + ) + + head_keywords = heads_config.get("head_modules", []) + last_layer_keywords = heads_config.get("last_layer_modules", []) + + for name, param in model.named_parameters(): + if any(name.startswith(kw) for kw in head_keywords + last_layer_keywords): + param.requires_grad = True + else: + param.requires_grad = False + + else: + raise ValueError( + f"Unknown finetuning strategy: {strategy['method']}. Available methods " + "are: 'full', 'lora', 'heads'." + ) + + model.finetune_config = strategy + + inherit_heads_config = strategy["inherit_heads"] + if inherit_heads_config: + for dest_target_name, source_target_name in inherit_heads_config.items(): + model_parameters = dict(model.named_parameters()) + if not any(f".{source_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the source target name '{source_target_name}' was not found in " + "the model. Please specify the correct source target name." + ) + if not any(f".{dest_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the destination target name '{dest_target_name}' was not found " + "in the model. Please specify the correct destination target name." + ) + for name, param in model_parameters.items(): + if f".{source_target_name}." in name: + corresponding_dest_name = name.replace( + source_target_name, dest_target_name + ) + if corresponding_dest_name in model_parameters: + model_parameters[corresponding_dest_name].data.copy_(param.data) + else: + raise ValueError( + f"Destination head '{dest_target_name}' not found in model." + ) + return model + + +def inject_lora_layers( + model: nn.Module, + target_modules: Tuple[str, ...] = ("input_linear", "output_linear"), + rank: int = 4, + alpha: float = 1.0, + device: Optional[str] = None, + dtype: Optional[torch.dtype] = None, +) -> nn.Module: + """ + Inject LoRA layers into the model. 
+ This function replaces the specified linear layers in the model with + LoRALinear layers. + + :param model: The model to modify. + :param target_modules: A tuple of strings specifying the names of the attributes of + the modules to be replaced with LoRA layers. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + :param device: The device to which the LoRA layers should be moved. If None, the + LoRA layers will be on the same device as the original model. + :param dtype: The data type to which the LoRA layers should be converted. If + None, the LoRA layers will have the same dtype as the original model. + :return: The modified model with LoRA layers injected. + """ + for _, module in model.named_modules(): + for attr in target_modules: + if hasattr(module, attr): + linear = getattr(module, attr) + if isinstance(linear, nn.Linear): + lora_linear = LoRALinear(linear, rank=rank, alpha=alpha) + lora_linear = lora_linear.to(dtype=dtype, device=device) + setattr(module, attr, lora_linear) + return model + + +class LoRALinear(nn.Module): + """ + LoRA Linear layer. + This is a wrapper around nn.Linear that adds LoRA functionality. + LoRA is a technique for low-rank adaptation of large language models. + It allows for efficient fine-tuning of large models by injecting low-rank + matrices into the model's weights. + + :param linear_layer: The original linear layer to be wrapped. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + """ + + def __init__(self, linear_layer: nn.Module, rank: int = 4, alpha: float = 1.0): + super().__init__() + self.linear = linear_layer + self.lora_A = nn.Linear(linear_layer.in_features, rank, bias=False) + self.lora_B = nn.Linear(rank, linear_layer.out_features, bias=False) + self.scaling = alpha / rank + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + self.scaling * self.lora_B(self.lora_A(x)) diff --git a/src/metatrain/pet_2/modules/nef.py b/src/metatrain/pet_2/modules/nef.py new file mode 100644 index 0000000000..20ed6dc8fc --- /dev/null +++ b/src/metatrain/pet_2/modules/nef.py @@ -0,0 +1,248 @@ +""" +Module with functions to manipulate NEF (Node Edge Feature) arrays. + +The NEF representation is what the internals of PET use. +In the NEF representation, the first dimension is the center node +(i.e. the "i" node in an "i -> j" edge), and the second dimension +is the edges for that node. Not all center nodes have the same number +of edges, so padding is used to ensure that all nodes have the same +number of edges. + +Most of the functions have the purpose of converting between +edge arrays with shape (n_edges, ...) and NEF arrays with shape +(n_nodes, n_edges_per_node, ...). +""" + +from typing import List, Optional, Tuple + +import torch + + +def get_nef_indices( + centers: torch.Tensor, n_nodes: int, n_edges_per_node: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes tensors of indices useful to convert between edge + and NEF layouts; the usage and function of `nef_indices` and + `nef_to_edges_neighbor` is clear in the ``edge_array_to_nef`` + and ``nef_array_to_edges`` functions below. + + :param centers: A 1D tensor of shape (n_edges,) containing the + indices of the center nodes for each edge, with the center nodes + being the "i" node in an "i -> j" edge. + :param n_nodes: The number of nodes in the graph. + :param n_edges_per_node: The maximum number of edges per node. 
+ + :return: A tuple with three tensors (nef_indices, nef_to_edges_neighbor, nef_mask). + In particular: + nef_array = edge_array[nef_indices] + edge_array = nef_array[centers, nef_to_edges_neighbor] + The third output, nef_mask, is a mask that can be used to + filter out the padding values in the NEF array, as different + nodes will have, in general, different number of edges. + """ + + bincount = torch.bincount(centers, minlength=n_nodes) + + arange = torch.arange(n_edges_per_node, device=centers.device) + arange_expanded = arange.view(1, -1).expand(n_nodes, -1) + nef_mask = arange_expanded < bincount.view(-1, 1) + + argsort = torch.argsort(centers, stable=True) + + nef_indices = torch.zeros( + (n_nodes, n_edges_per_node), dtype=torch.long, device=centers.device + ) + nef_indices[nef_mask] = argsort + + nef_to_edges_neighbor = torch.empty_like(centers, dtype=torch.long) + nef_to_edges_neighbor[argsort] = arange_expanded[nef_mask] + + return nef_indices, nef_to_edges_neighbor, nef_mask + + +def get_corresponding_edges(array: torch.Tensor) -> torch.Tensor: + """ + Computes the corresponding edge (i.e., the edge that goes in the + opposite direction) for each edge in the array; this is useful + in the message-passing operation. + + :param array: A 2D tensor of shape (n_edges, 5). For each i -> j + edge, the first column contains the index of the center node i, + the second column contains the index of the neighbor node j, + and the last three columns contain the cell shifts along x, y, and z + directions, respectively. + + :return: A 1D tensor of shape (n_edges,) containing, for each edge, + the index of the corresponding edge (i.e., the edge that goes + in the opposite direction). If the input array is empty, an + empty tensor is returned. + """ + + if array.numel() == 0: + return torch.empty((0,), dtype=array.dtype, device=array.device) + + array = array.to(torch.int64) # avoid overflow + + centers = array[:, 0] + neighbors = array[:, 1] + cell_shifts_x = array[:, 2] + cell_shifts_y = array[:, 3] + cell_shifts_z = array[:, 4] + + # will be useful later + negative_cell_shifts_x = -cell_shifts_x + negative_cell_shifts_y = -cell_shifts_y + negative_cell_shifts_z = -cell_shifts_z + + # create a unique identifier for each edge + # first, we shift the cell_shifts so that the minimum value is 0 + min_cell_shift_x = cell_shifts_x.min() + cell_shifts_x = cell_shifts_x - min_cell_shift_x + negative_cell_shifts_x = negative_cell_shifts_x - min_cell_shift_x + + min_cell_shift_y = cell_shifts_y.min() + cell_shifts_y = cell_shifts_y - min_cell_shift_y + negative_cell_shifts_y = negative_cell_shifts_y - min_cell_shift_y + + min_cell_shift_z = cell_shifts_z.min() + cell_shifts_z = cell_shifts_z - min_cell_shift_z + negative_cell_shifts_z = negative_cell_shifts_z - min_cell_shift_z + + max_centers_neigbors = centers.max() + 1 # same as neighbors.max() + 1 + max_shift_x = cell_shifts_x.max() + 1 + max_shift_y = cell_shifts_y.max() + 1 + max_shift_z = cell_shifts_z.max() + 1 + + size_1 = max_shift_z + size_2 = max_shift_y * size_1 + size_3 = max_shift_x * size_2 + size_4 = max_centers_neigbors * size_3 + + unique_id = ( + centers * size_4 + + neighbors * size_3 + + cell_shifts_x * size_2 + + cell_shifts_y * size_1 + + cell_shifts_z + ) + + # the inverse is the same, but centers and neighbors are swapped + # and we use the negative values of the cell_shifts + unique_id_inverse = ( + neighbors * size_4 + + centers * size_3 + + negative_cell_shifts_x * size_2 + + negative_cell_shifts_y * size_1 + + 
negative_cell_shifts_z + ) + + unique_id_argsort = unique_id.argsort() + unique_id_inverse_argsort = unique_id_inverse.argsort() + + corresponding_edges = torch.empty_like(centers) + corresponding_edges[unique_id_argsort] = unique_id_inverse_argsort + + return corresponding_edges.to(array.dtype) + + +def edge_array_to_nef( + edge_array: torch.Tensor, + nef_indices: torch.Tensor, + mask: Optional[torch.Tensor] = None, + fill_value: float = 0.0, +) -> torch.Tensor: + """Converts an edge array to a NEF array. + + :param edge_array: A tensor where the first dimension is the index of + the edge, i.e. with shape (n_edges, ...). + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param mask: An optional boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. If provided, + the output NEF array will have the values in the positions + where the mask is False set to ``fill_value``. + :param fill_value: The value to use to fill the positions in the + NEF array where the mask is False. Only used if ``mask`` is + provided. + + :return: A tensor with the same information as ``edge_array``, + but in NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + If ``mask`` is provided, the values in the positions where + the mask is False are set to ``fill_value``. + """ + if mask is None: + return edge_array[nef_indices] + else: + return torch.where( + mask.reshape(mask.shape + (1,) * (len(edge_array.shape) - 1)), + edge_array[nef_indices], + fill_value, + ) + + +def nef_array_to_edges( + nef_array: torch.Tensor, centers: torch.Tensor, nef_to_edges_neighbor: torch.Tensor +) -> torch.Tensor: + """Converts a NEF array to an edge array. + + :param nef_array: A tensor where the first two dimensions are the + indices of the NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + :param centers: The indices of the center nodes for each edge. + :param nef_to_edges_neighbor: The indices of the edges for each + neighbor in the NEF layout, as returned by the ``get_nef_indices`` function. + + :return: A tensor with the same information as ``nef_array``, + but in edge layout, i.e. with shape (n_edges, ...). + """ + return nef_array[centers, nef_to_edges_neighbor] + + +def compute_reversed_neighbor_list( + nef_indices: torch.Tensor, + corresponding_edges: torch.Tensor, + nef_mask: torch.Tensor, +) -> torch.Tensor: + """ + Creates a reversed neighborlist, where for each + center atom `i` and its neighbor `j` in the original + neighborlist, the position of atom `i` in the list + of neighbors of atom `j` is returned. + + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param corresponding_edges: The indices of the corresponding edges, + as returned by the ``get_corresponding_edges`` function. + :param nef_mask: A boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. + :return: A tensor of the same shape as ``nef_indices``, + where each entry contains the position of the center + atom in the neighborlist of the corresponding neighbor atom. 
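# --- Illustrative aside (not part of the patch): round-tripping a toy edge
# array through the NEF layout with the helpers defined in this module
# (assumed to be in scope). Values are made up for two atoms and three edges.
import torch

centers = torch.tensor([0, 0, 1])                 # two edges centered on atom 0, one on atom 1
edge_array = torch.tensor([[1.0], [2.0], [3.0]])  # one feature per edge
nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices(centers, 2, 2)
nef = edge_array_to_nef(edge_array, nef_indices, nef_mask, fill_value=0.0)
# nef has shape (2, 2, 1); the unused slot of atom 1 is padded with 0.0
back = nef_array_to_edges(nef, centers, nef_to_edges_neighbor)
# back reproduces edge_array exactly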
+ """ + num_atoms, max_num_neighbors = nef_indices.shape + + flat_edge_indices = nef_indices.reshape(-1) + flat_positions = torch.arange(max_num_neighbors, device=nef_indices.device).repeat( + num_atoms + ) + flat_mask = nef_mask.reshape(-1) + + if flat_edge_indices.numel() == 0: + max_edge_index = 0 + else: + max_edge_index = int(flat_edge_indices.max().item()) + 1 + size: List[int] = [max_edge_index] + + edge_index_to_position = torch.full( + size, + 0, + dtype=torch.long, + device=nef_indices.device, + ) + edge_index_to_position[flat_edge_indices[flat_mask]] = flat_positions[flat_mask] + + reverse_edge_idx = corresponding_edges[nef_indices] + reversed_neighbor_list = edge_index_to_position[reverse_edge_idx] + reversed_neighbor_list = reversed_neighbor_list.masked_fill(~nef_mask, 0) + + return reversed_neighbor_list diff --git a/src/metatrain/pet_2/modules/structures.py b/src/metatrain/pet_2/modules/structures.py new file mode 100644 index 0000000000..f7e67f9eaf --- /dev/null +++ b/src/metatrain/pet_2/modules/structures.py @@ -0,0 +1,377 @@ +from typing import List, Optional, Tuple + +import torch +from metatensor.torch import Labels +from metatomic.torch import NeighborListOptions, System + +from .adaptive_cutoff import get_adaptive_cutoffs +from .nef import ( + compute_reversed_neighbor_list, + edge_array_to_nef, + get_corresponding_edges, + get_nef_indices, +) +from .utilities import cutoff_func_bump, cutoff_func_cosine + + +def concatenate_structures( + systems: List[System], + neighbor_list_options: NeighborListOptions, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, +]: + """ + Concatenate a list of systems into a single batch. + + :param systems: List of systems to concatenate. + :param neighbor_list_options: Options for the neighbor list. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the concatenated positions, centers, neighbors, + species, cells, cell shifts, system indices, and sample labels. 
+ """ + + positions: List[torch.Tensor] = [] + centers: List[torch.Tensor] = [] + neighbors: List[torch.Tensor] = [] + species: List[torch.Tensor] = [] + cell_shifts: List[torch.Tensor] = [] + cells: List[torch.Tensor] = [] + system_indices: List[torch.Tensor] = [] + atom_indices: List[torch.Tensor] = [] + node_counter = 0 + + for i, system in enumerate(systems): + assert len(system.known_neighbor_lists()) >= 1, "no neighbor list found" + neighbor_list = system.get_neighbor_list(neighbor_list_options) + nl_values = neighbor_list.samples.values + + centers_values = nl_values[:, 0] + neighbors_values = nl_values[:, 1] + cell_shifts_values = nl_values[:, 2:] + + if selected_atoms is not None: + system_selected_atoms = selected_atoms.values[:, 1][ + selected_atoms.values[:, 0] == i + ] + unique_centers = torch.unique(centers_values) + system_selected_atoms = torch.unique( + torch.cat([system_selected_atoms, unique_centers]) + ) + # calculate the mapping from the ghost atoms to the real atoms + if torch.numel(unique_centers) == 0: + max_center_index = -1 + else: + max_center_index = int(unique_centers.max()) + ghost_to_real_index = torch.full( + [ + max_center_index + 1, + ], + -1, + device=centers_values.device, + dtype=centers_values.dtype, + ) + for j, unique_center_index in enumerate(unique_centers): + ghost_to_real_index[unique_center_index] = j + + centers_values = ghost_to_real_index[centers_values] + neighbors_values = ghost_to_real_index[neighbors_values] + else: + system_selected_atoms = torch.arange( + len(system), device=system.positions.device + ) + + positions.append(system.positions[system_selected_atoms]) + species.append(system.types[system_selected_atoms]) + + centers.append(centers_values + node_counter) + neighbors.append(neighbors_values + node_counter) + cell_shifts.append(cell_shifts_values) + + cells.append(system.cell) + + node_counter += len(system_selected_atoms) + system_indices.append( + torch.full((len(system_selected_atoms),), i, device=system.positions.device) + ) + atom_indices.append( + torch.arange(len(system_selected_atoms), device=system.positions.device) + ) + + positions = torch.cat(positions) + centers = torch.cat(centers) + neighbors = torch.cat(neighbors) + species = torch.cat(species) + cells = torch.stack(cells) + cell_shifts = torch.cat(cell_shifts) + system_indices = torch.cat(system_indices) + atom_indices = torch.cat(atom_indices) + + sample_values = torch.stack( + [system_indices, atom_indices], + dim=1, + ) + sample_labels = Labels( + names=["system", "atom"], + values=sample_values, + assume_unique=True, + ) + + return ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) + + +def systems_to_batch( + systems: List[System], + options: NeighborListOptions, + all_species_list: List[int], + species_to_species_index: torch.Tensor, + cutoff_function: str, + cutoff_width: float, + num_neighbors_adaptive: Optional[float] = None, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, + torch.Tensor, + torch.Tensor, +]: + """ + Converts a list of systems to a batch required for the PET model. + + :param systems: List of systems to convert to a batch. + :param options: Options for the neighbor list. + :param all_species_list: List of all atomic species in the dataset. + :param species_to_species_index: Mapping from atomic species to species indices. 
+ :param cutoff_function: Type of the smoothing function at the cutoff. + :param cutoff_width: Width of the cutoff function for a cutoff mask. + :param num_neighbors_adaptive: Optional maximum number of neighbors per atom. + If provided, the adaptive cutoff scheme will be used for each atom to + approximately select this number of neighbors. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the batch tensors. + The batch consists of the following tensors: + - `element_indices_nodes`: The atomic species of the central atoms + - `element_indices_neighbors`: The atomic species of the neighboring atoms + - `edge_vectors`: The cartesian edge vectors between the central atoms and their + neighbors + - `edge_distances`: The distances between the central atoms and their neighbors + - `padding_mask`: A padding mask indicating which neighbors are real, and which + are padded + - `reverse_neighbor_index`: The reversed neighbor list for each central atom + - `cutoff_factors`: The cutoff function values for each edge + - `system_indices`: The system index for each atom in the batch + - `sample_labels`: Labels indicating the system and atom indices for each atom + + """ + ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) = concatenate_structures(systems, options, selected_atoms) + + # somehow the backward of this operation is very slow at evaluation, + # where there is only one cell, therefore we simplify the calculation + # for that case + if len(cells) == 1: + cell_contributions = cell_shifts.to(cells.dtype) @ cells[0] + else: + cell_contributions = torch.einsum( + "ab, abc -> ac", + cell_shifts.to(cells.dtype), + cells[system_indices[centers]], + ) + edge_vectors = positions[neighbors] - positions[centers] + cell_contributions + edge_distances = torch.norm(edge_vectors, dim=-1) + 1e-15 + + if selected_atoms is not None: + if torch.numel(centers) == 0: + num_nodes = 0 + else: + num_nodes = int(centers.max()) + 1 + else: + num_nodes = len(positions) + + if num_neighbors_adaptive is not None: + with torch.profiler.record_function("PET::get_adaptive_cutoffs"): + # Adaptive cutoff scheme to approximately select `num_neighbors_adaptive` + # neighbors for each atom + atomic_cutoffs = get_adaptive_cutoffs( + centers, + edge_distances, + num_neighbors_adaptive, + num_nodes, + options.cutoff, + cutoff_width=cutoff_width, + ) + # Symmetrize the cutoffs between pairs of atoms (PET needs this symmetry + # due to its corresponding edge indexing ij -> ji) + pair_cutoffs = (atomic_cutoffs[centers] + atomic_cutoffs[neighbors]) / 2.0 + with torch.profiler.record_function("PET::adaptive_cutoff_masking"): + # Apply cutoff mask + cutoff_mask = edge_distances <= pair_cutoffs + + pair_cutoffs = pair_cutoffs[cutoff_mask] + centers = centers[cutoff_mask] + neighbors = neighbors[cutoff_mask] + edge_vectors = edge_vectors[cutoff_mask] + cell_shifts = cell_shifts[cutoff_mask] + edge_distances = edge_distances[cutoff_mask] + else: + pair_cutoffs = options.cutoff * torch.ones( + len(centers), device=positions.device, dtype=positions.dtype + ) + + num_neighbors = torch.bincount(centers, minlength=num_nodes) + max_edges_per_node = int(torch.max(num_neighbors)) + + # uncomment these to print out stats on the adaptive cutoff behavior + # print("adaptive_cutoffs", *pair_cutoffs.tolist()) + # print("num_neighbors", *num_neighbors.tolist()) + + if cutoff_function.lower() == "bump": + # use bump switching 
function for adaptive cutoff + cutoff_factors = cutoff_func_bump(edge_distances, pair_cutoffs, cutoff_width) + elif cutoff_function.lower() == "cosine": + # backward-compatible cosine swithcing for fixed cutoff + cutoff_factors = cutoff_func_cosine(edge_distances, pair_cutoffs, cutoff_width) + else: + raise ValueError( + f"Unknown cutoff function type: {cutoff_function}. " + f"Supported types are 'Cosine' and 'Bump'." + ) + + # Convert to NEF (Node-Edge-Feature) format: + nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices( + centers, num_nodes, max_edges_per_node + ) + + # Element indices + element_indices_nodes = species_to_species_index[species] + element_indices_neighbors = element_indices_nodes[neighbors] + + # Send everything to NEF: + edge_vectors = edge_array_to_nef(edge_vectors, nef_indices) + edge_distances = torch.sqrt(torch.sum(edge_vectors**2, dim=2) + 1e-15) + element_indices_neighbors = edge_array_to_nef( + element_indices_neighbors, nef_indices + ) + cutoff_factors = edge_array_to_nef(cutoff_factors, nef_indices, nef_mask, 0.0) + + corresponding_edges = get_corresponding_edges( + torch.concatenate( + [centers.unsqueeze(-1), neighbors.unsqueeze(-1), cell_shifts], + dim=-1, + ) + ) + + # These are the two arrays we need for message passing with edge reversals, + # if indexing happens in a two-dimensional way: + # edges_ji = edges_ij[reversed_neighbor_list, neighbors_index] + reversed_neighbor_list = compute_reversed_neighbor_list( + nef_indices, corresponding_edges, nef_mask + ) + neighbors_index = edge_array_to_nef(neighbors, nef_indices).to(torch.int64) + + # Here, we compute the array that allows indexing into a flattened + # version of the edge array (where the first two dimensions are merged): + reverse_neighbor_index = ( + neighbors_index * neighbors_index.shape[1] + reversed_neighbor_list + ) + # At this point, we have `reverse_neighbor_index[~nef_mask] = 0`, which however + # creates too many of the same index which slows down backward enormously. + # (See see https://github.com/pytorch/pytorch/issues/41162) + # We therefore replace the padded indices with a sequence of unique indices. + reverse_neighbor_index[~nef_mask] = torch.arange( + int(torch.sum(~nef_mask)), device=reverse_neighbor_index.device + ) + + return ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + nef_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + sample_labels, + centers, + nef_to_edges_neighbor, + ) + + +def get_pair_sample_labels( + systems: List[System], + sample_labels: Labels, + nl_options: NeighborListOptions, + device: torch.device, +) -> Labels: + """ + Builds the pair samples labels for the input ``systems``, based on the pre-computed + neighbor list. These are 'off-site', i.e. not including self-interactions. + + :param systems: List of systems to build the pair sample labels for. + :param sample_labels: The sample labels for per-atom quantities. + :param nl_options: The neighbor list options to use for building the offsite labels. + :param device: The device to put the labels on. + :return: A dictionary with the pair sample labels for the onsite and offsite blocks. 
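+
+    A minimal usage sketch (illustrative only; it assumes the ``systems`` already
+    carry the neighbor list requested by ``nl_options``)::
+
+        pair_labels = get_pair_sample_labels(systems, sample_labels, nl_options, device)
+        # pair_labels.names == ["system", "first_atom", "second_atom",
+        #                       "cell_shift_a", "cell_shift_b", "cell_shift_c"]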
+ """ + sample_names = [ + "system", + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ] + + pair_sample_values = [] + for system_idx, system in enumerate(systems): + neighbor_list = system.get_neighbor_list(nl_options) + nl_values = neighbor_list.samples.values + + pair_sample_values.append( + torch.hstack( + [ + torch.full( + (nl_values.shape[0], 1), + system_idx, + dtype=torch.int32, + device=device, + ), + nl_values, + ], + ) + ) + pair_sample_values = torch.vstack(pair_sample_values) + pair_sample_labels = Labels(sample_names, pair_sample_values).to(device=device) + + return pair_sample_labels diff --git a/src/metatrain/pet_2/modules/transformer.py b/src/metatrain/pet_2/modules/transformer.py new file mode 100644 index 0000000000..decb395090 --- /dev/null +++ b/src/metatrain/pet_2/modules/transformer.py @@ -0,0 +1,555 @@ +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from .utilities import DummyModule +from sphericart.torch import SolidHarmonics + + +AVAILABLE_NORMALIZATIONS = ["LayerNorm", "RMSNorm"] +AVAILABLE_TRANSFORMER_TYPES = ["PostLN", "PreLN"] +AVAILABLE_ACTIVATIONS = ["SiLU", "SwiGLU"] + + +class FeedForward(nn.Module): + def __init__(self, d_model: int, dim_feedforward: int, activation: str) -> None: + super().__init__() + + # Check if activation is "swiglu" string + if activation.lower() == "swiglu": + # SwiGLU mode: single projection produces both "value" and "gate" + self.w_in = nn.Linear(d_model, 2 * dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = torch.nn.Identity() + self.is_swiglu = True + else: + # Standard mode: regular activation function + self.w_in = nn.Linear(d_model, dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = getattr(F, activation.lower()) + self.is_swiglu = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_swiglu: + # SwiGLU activation: split into value and gate + v, g = self.w_in(x).chunk(2, dim=-1) + x = v * torch.sigmoid(g) + x = self.w_out(x) + else: + # Standard activation + x = self.w_in(x) + x = self.activation(x) + x = self.w_out(x) + return x + + +class AttentionBlock(nn.Module): + """ + Multi-head attention block. + + :param total_dim: The total dimension of the input and output tensors. + :param num_heads: The number of attention heads. + :param temperature: An additional scaling factor for attention scores. + This is combined with the standard scaling by the square root of + the head dimension. + :param epsilon: A small value to avoid division by zero. + """ + + def __init__( + self, + total_dim: int, + num_heads: int, + temperature: float, + epsilon: float = 1e-15, + ) -> None: + super(AttentionBlock, self).__init__() + + self.input_linear = nn.Linear(total_dim, 3 * total_dim) + self.output_linear = nn.Linear(total_dim, total_dim) + + self.num_heads = num_heads + self.epsilon = epsilon + self.temperature = temperature + if total_dim % num_heads != 0: + raise ValueError("total dimension is not divisible by the number of heads") + self.head_dim = total_dim // num_heads + + def forward( + self, x: torch.Tensor, cutoff_factors: torch.Tensor, use_manual_attention: bool + ) -> torch.Tensor: + """ + Forward pass for the attention block. 
+ + :param x: The input tensor, of shape (batch_size, seq_length, total_dim) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: The output tensor, of shape (batch_size, seq_length, total_dim) + """ + initial_shape = x.shape + x = self.input_linear(x) + x = x.reshape( + initial_shape[0], initial_shape[1], 3, self.num_heads, self.head_dim + ) + x = x.permute(2, 0, 3, 1, 4) + + queries, keys, values = x[0], x[1], x[2] + attn_weights = torch.clamp(cutoff_factors[:, None, :, :], self.epsilon) + attn_weights = torch.log(attn_weights) + if use_manual_attention: + x = manual_attention(queries, keys, values, attn_weights, self.temperature) + else: + x = torch.nn.functional.scaled_dot_product_attention( + queries, + keys, + values, + attn_mask=attn_weights, + scale=1.0 / (self.head_dim**0.5 * self.temperature), + ) + x = x.transpose(1, 2).reshape(initial_shape) + x = self.output_linear(x) + return x + + +class TransformerLayer(torch.nn.Module): + """ + Single layer of a Transformer. + + :param d_model: The dimension of the model. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param temperature: An additional scaling factor for attention scores. 
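+
+    A minimal construction sketch (illustrative only; the hyperparameter values
+    below are arbitrary, and the input tensors are assumed to be prepared as in
+    :class:`Transformer`)::
+
+        layer = TransformerLayer(d_model=128, n_heads=8, dim_node_features=128)
+        nodes, edges = layer(
+            node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention=True
+        )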
+ """ + + def __init__( + self, + d_model: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + temperature: float = 1.0, + ) -> None: + super(TransformerLayer, self).__init__() + self.attention = AttentionBlock(d_model, n_heads, temperature) + self.transformer_type = transformer_type + self.d_model = d_model + norm_class = getattr(nn, norm) + self.norm_attention = norm_class(d_model) + self.norm_mlp = norm_class(d_model) + self.mlp = FeedForward(d_model, dim_feedforward, activation) + self.expanded_node_features = False + if dim_node_features != d_model: + self.expanded_node_features = True + self.center_contraction = nn.Linear(dim_node_features, d_model) + self.center_expansion = nn.Linear(d_model, dim_node_features) + self.norm_center_features = norm_class(dim_node_features) + self.center_mlp = FeedForward( + dim_node_features, 2 * dim_node_features, activation + ) + else: + self.center_contraction = torch.nn.Identity() + self.center_expansion = torch.nn.Identity() + self.norm_center_features = torch.nn.Identity() + self.center_mlp = torch.nn.Identity() + + def _forward_pre_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + new_tokens = self.attention( + self.norm_attention(tokens), cutoff_factors, use_manual_attention + ) + output_node_embeddings, output_edge_embeddings = torch.split( + new_tokens, [1, new_tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + + output_edge_embeddings = edge_embeddings + output_edge_embeddings + output_edge_embeddings = output_edge_embeddings + self.mlp( + self.norm_mlp(output_edge_embeddings) + ) + + return output_node_embeddings, output_edge_embeddings + + def _forward_post_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + tokens = self.norm_attention( + tokens + self.attention(tokens, cutoff_factors, use_manual_attention) + ) + tokens = self.norm_mlp(tokens + self.mlp(tokens)) + output_node_embeddings, output_edge_embeddings = torch.split( + tokens, [1, tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + return output_node_embeddings, output_edge_embeddings + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward 
pass for a single Transformer layer. + + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + if self.transformer_type == "PostLN": + node_embeddings, edge_embeddings = self._forward_post_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + if self.transformer_type == "PreLN": + node_embeddings, edge_embeddings = self._forward_pre_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class Transformer(torch.nn.Module): + """ + Transformer implementation. + + :param d_model: The dimension of the model. + :param num_layers: The number of transformer layers. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param attention_temperature: The temperature scaling factor for attention + scores. This is combined with the standard scaling by the square root of + the head dimension. + """ + + def __init__( + self, + d_model: int, + num_layers: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + attention_temperature: float = 1.0, + ) -> None: + super(Transformer, self).__init__() + if norm not in AVAILABLE_NORMALIZATIONS: + raise ValueError( + f"Unknown normalization flag: {norm}. " + f"Please choose from: {AVAILABLE_NORMALIZATIONS}" + ) + + if transformer_type not in AVAILABLE_TRANSFORMER_TYPES: + raise ValueError( + f"Unknown transformer flag: {transformer_type}. " + f"Please choose from: {AVAILABLE_TRANSFORMER_TYPES}" + ) + self.transformer_type = transformer_type + + if activation not in AVAILABLE_ACTIVATIONS: + raise ValueError( + f"Unknown activation flag: {activation}. " + f"Please choose from: {AVAILABLE_ACTIVATIONS}" + ) + + self.layers = nn.ModuleList( + [ + TransformerLayer( + d_model=d_model, + n_heads=n_heads, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + temperature=attention_temperature, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the Transformer. 
+ + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + for layer in self.layers: + node_embeddings, edge_embeddings = layer( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class CartesianTransformer(torch.nn.Module): + """ + Cartesian Transformer implementation for handling 3D coordinates. + + :param cutoff: The cutoff distance for neighbor interactions. + :param cutoff_width: The width of the cutoff function. + :param d_model: The dimension of the model. + :param n_head: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param n_layers: The number of transformer layers. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param attention_temperature: The temperature scaling factor for attention scores. + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param n_atomic_species: The number of atomic species. + :param is_first: Whether this is the first transformer in the model. + """ + + def __init__( + self, + cutoff: float, + cutoff_width: float, + d_model: int, + n_head: int, + dim_node_features: int, + dim_feedforward: int, + n_layers: int, + norm: str, + activation: str, + attention_temperature: float, + transformer_type: str, + n_atomic_species: int, + is_first: bool, + ) -> None: + super(CartesianTransformer, self).__init__() + self.is_first = is_first + self.cutoff = cutoff + self.cutoff_width = cutoff_width + self.trans = Transformer( + d_model=d_model, + num_layers=n_layers, + n_heads=n_head, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + attention_temperature=attention_temperature, + ) + + self.spherical_harmonics = SolidHarmonics(l_max=10) + self.edge_embedder = nn.Linear(11**2, d_model) + self.rmsnorm = nn.LayerNorm(d_model) + + if not is_first: + n_merge = 3 + else: + n_merge = 2 + + self.compress = nn.Sequential( + nn.Linear(n_merge * d_model, d_model), + torch.nn.SiLU(), + nn.Linear(d_model, d_model), + ) + + self.neighbor_embedder = DummyModule() # for torchscript + if not is_first: + self.neighbor_embedder = nn.Embedding(n_atomic_species, d_model) + + def forward( + self, + input_node_embeddings: torch.Tensor, + input_messages: torch.Tensor, + element_indices_neighbors: torch.Tensor, + edge_vectors: torch.Tensor, + padding_mask: torch.Tensor, + edge_distances: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the CartesianTransformer. 
+ + :param input_node_embeddings: The input node embeddings, of shape + (n_nodes, d_model) + :param input_messages: The input messages to the transformer, of shape + (n_nodes, max_num_neighbors, d_model) + :param element_indices_neighbors: The atomic species of the neighboring atoms, + of shape (n_nodes, max_num_neighbors) + :param edge_vectors: The cartesian edge vectors between the central atoms and + their neighbors, of shape (n_nodes, max_num_neighbors, 3) + :param padding_mask: A padding mask indicating which neighbors are real, and + which are padded, of shape (n_nodes, max_num_neighbors) + :param edge_distances: The distances between the central atoms and their + neighbors, of shape (n_nodes, max_num_neighbors) + :param cutoff_factors: The cutoff factors for the edges, of shape + (n_nodes, max_num_neighbors) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (n_nodes, d_model) + - The output edge embeddings, of shape (n_nodes, max_num_neighbors, d_model) + """ + node_embeddings = input_node_embeddings + # edge_embeddings = [edge_vectors, edge_distances[:, :, None]] + # edge_embeddings = torch.cat(edge_embeddings, dim=2) + edge_embeddings = self.spherical_harmonics( + edge_vectors.reshape(-1, 3) + ).reshape(edge_vectors.shape[0], edge_vectors.shape[1], -1) + edge_embeddings = self.edge_embedder(edge_embeddings) + edge_embeddings = self.rmsnorm(edge_embeddings) + + if not self.is_first: + neighbor_elements_embeddings = self.neighbor_embedder( + element_indices_neighbors + ) + edge_tokens = torch.cat( + [edge_embeddings, neighbor_elements_embeddings, input_messages], dim=2 + ) + else: + neighbor_elements_embeddings = torch.empty( + 0, device=edge_vectors.device, dtype=edge_vectors.dtype + ) # for torch script + edge_tokens = torch.cat([edge_embeddings, input_messages], dim=2) + + edge_tokens = self.compress(edge_tokens) + # tokens = torch.cat([node_elements_embedding[:, None, :], tokens], dim=1) + + padding_mask_with_central_token = torch.ones( + padding_mask.shape[0], dtype=torch.bool, device=padding_mask.device + ) + total_padding_mask = torch.cat( + [padding_mask_with_central_token[:, None], padding_mask], dim=1 + ) + + cutoff_subfactors = torch.ones( + padding_mask.shape[0], + dtype=cutoff_factors.dtype, + device=padding_mask.device, + ) + cutoff_factors = torch.cat([cutoff_subfactors[:, None], cutoff_factors], dim=1) + cutoff_factors[~total_padding_mask] = 0.0 + + cutoff_factors = cutoff_factors[:, None, :] + cutoff_factors = cutoff_factors.repeat(1, cutoff_factors.shape[2], 1) + + initial_num_tokens = edge_vectors.shape[1] + max_num_tokens = input_messages.shape[1] + + output_node_embeddings, output_edge_embeddings = self.trans( + node_embeddings[:, None, :], + edge_tokens[:, :max_num_tokens, :], + cutoff_factors=cutoff_factors[ + :, : (max_num_tokens + 1), : (max_num_tokens + 1) + ], + use_manual_attention=use_manual_attention, + ) + if max_num_tokens < initial_num_tokens: + padding = torch.zeros( + output_edge_embeddings.shape[0], + initial_num_tokens - max_num_tokens, + output_edge_embeddings.shape[2], + device=output_edge_embeddings.device, + ) + output_edge_embeddings = torch.cat([output_edge_embeddings, padding], dim=1) + output_node_embeddings = output_node_embeddings.squeeze(1) + return output_node_embeddings, 
output_edge_embeddings + + +def manual_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: torch.Tensor, + temperature: float, +) -> torch.Tensor: + """ + Implements the attention operation manually, using basic PyTorch operations. + We need it because the built-in PyTorch attention does not support double backward, + which is needed when training with conservative forces. + + :param q: The queries + :param k: The keys + :param v: The values + :param attn_mask: The attention mask + :param temperature: An additional scaling factor for attention scores. + :return: The result of the attention operation + """ + attention_weights = ( + torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5 * temperature) + ) + attn_mask + attention_weights = attention_weights.softmax(dim=-1) + attention_output = torch.matmul(attention_weights, v) + return attention_output \ No newline at end of file diff --git a/src/metatrain/pet_2/modules/utilities.py b/src/metatrain/pet_2/modules/utilities.py new file mode 100644 index 0000000000..5795c5aed8 --- /dev/null +++ b/src/metatrain/pet_2/modules/utilities.py @@ -0,0 +1,63 @@ +import torch + + +def cutoff_func_bump( + values: torch.Tensor, cutoff: torch.Tensor, width: float, eps: float = 1e-6 +) -> torch.Tensor: + """ + Bump cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :param eps: Avoid computing at values too close to the edges. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + f[mask_active] = 0.5 * ( + 1 + torch.tanh(1 / torch.tan(torch.pi * scaled_values[mask_active])) + ) + f[mask_smaller] = 1.0 + + return f + + +def cutoff_func_cosine( + values: torch.Tensor, cutoff: torch.Tensor, width: float +) -> torch.Tensor: + """ + Cosine cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + + f[mask_active] = 0.5 + 0.5 * torch.cos(torch.pi * scaled_values[mask_active]) + f[mask_smaller] = 1.0 + return f + + +class DummyModule(torch.nn.Module): + """Dummy torch module to make torchscript happy. 
+ This model should never be run""" + + def __init__(self) -> None: + super(DummyModule, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise RuntimeError("This model should never be run") diff --git a/src/metatrain/pet_2/trainer.py b/src/metatrain/pet_2/trainer.py new file mode 100644 index 0000000000..49e82212bc --- /dev/null +++ b/src/metatrain/pet_2/trainer.py @@ -0,0 +1,639 @@ +import copy +import logging +import math +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Union, cast + +import torch +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader, DistributedSampler + +from metatrain.utils.abc import ModelInterface, TrainerInterface +from metatrain.utils.additive import get_remove_additive_transform +from metatrain.utils.augmentation import RotationalAugmenter +from metatrain.utils.data import ( + CollateFn, + CombinedDataLoader, + Dataset, + get_num_workers, + unpack_batch, + validate_num_workers, +) +from metatrain.utils.distributed.batch_utils import should_skip_batch +from metatrain.utils.distributed.distributed_data_parallel import ( + DistributedDataParallel, +) +from metatrain.utils.distributed.slurm import DistributedEnvironment +from metatrain.utils.evaluate_model import evaluate_model +from metatrain.utils.io import check_file_extension +from metatrain.utils.logging import ROOT_LOGGER, MetricLogger +from metatrain.utils.loss import LossAggregator, LossSpecification +from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists_transform, +) +from metatrain.utils.per_atom import average_by_num_atoms +from metatrain.utils.scaler import get_remove_scale_transform +from metatrain.utils.transfer import batch_to + +from . import checkpoints +from .documentation import TrainerHypers +from .model import PET +from .modules.finetuning import apply_finetuning_strategy + + +def get_scheduler( + optimizer: torch.optim.Optimizer, + train_hypers: TrainerHypers, + steps_per_epoch: int, +) -> LambdaLR: + """ + Get a CosineAnnealing learning-rate scheduler with warmup + + :param optimizer: The optimizer for which to create the scheduler. + :param train_hypers: The training hyperparameters. + :param steps_per_epoch: The number of steps per epoch. + :return: The learning rate scheduler. 
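+
+    A minimal usage sketch (illustrative only; it assumes an existing ``optimizer``,
+    the trainer hyperparameters, and a training dataloader; the scheduler is stepped
+    once per batch, mirroring the training loop below)::
+
+        scheduler = get_scheduler(optimizer, train_hypers, len(train_dataloader))
+        for batch in train_dataloader:
+            ...  # forward pass, loss, backward pass
+            optimizer.step()
+            scheduler.step()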
+ """ + total_steps = train_hypers["num_epochs"] * steps_per_epoch + warmup_steps = int(train_hypers["warmup_fraction"] * total_steps) + min_lr_ratio = 0.0 # hardcoded for now, could be made configurable in the future + + def lr_lambda(current_step: int) -> float: + if current_step < warmup_steps: + # Linear warmup + return float(current_step) / float(max(1, warmup_steps)) + else: + # Cosine decay + progress = (current_step - warmup_steps) / float( + max(1, total_steps - warmup_steps) + ) + cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) + return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay + + scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) + return scheduler + + +class Trainer(TrainerInterface[TrainerHypers]): + __checkpoint_version__ = 12 + + def __init__(self, hypers: TrainerHypers) -> None: + super().__init__(hypers) + + self.optimizer_state_dict: Optional[Dict[str, Any]] = None + self.scheduler_state_dict: Optional[Dict[str, Any]] = None + self.epoch: Optional[int] = None + self.best_epoch: Optional[int] = None + self.best_metric: Optional[float] = None + self.best_model_state_dict: Optional[Dict[str, Any]] = None + self.best_optimizer_state_dict: Optional[Dict[str, Any]] = None + + def train( + self, + model: PET, + dtype: torch.dtype, + devices: List[torch.device], + train_datasets: List[Union[Dataset, torch.utils.data.Subset]], + val_datasets: List[Union[Dataset, torch.utils.data.Subset]], + checkpoint_dir: str, + ) -> None: + assert dtype in PET.__supported_dtypes__ + + is_distributed = self.hypers["distributed"] + is_finetune = self.hypers["finetune"]["read_from"] is not None + + if is_distributed: + if len(devices) > 1: + raise ValueError( + "Requested distributed training with the `multi-gpu` device. " + " If you want to run distributed training with PET, please " + "set `device` to cuda." 
+                )
+            # the calculation of the device number works both when GPUs on different
+            # processes are not visible to each other and when they are
+            distr_env = DistributedEnvironment(self.hypers["distributed_port"])
+            device_number = distr_env.local_rank % torch.cuda.device_count()
+            device = torch.device("cuda", device_number)
+            torch.distributed.init_process_group(backend="nccl", device_id=device)
+            world_size = torch.distributed.get_world_size()
+            rank = torch.distributed.get_rank()
+        else:
+            rank = 0
+            device = devices[0]
+            # only one device, as we don't support non-distributed multi-gpu for now
+
+        if is_distributed:
+            logging.info(f"Training on {world_size} devices with dtype {dtype}")
+        else:
+            logging.info(f"Training on device {device} with dtype {dtype}")
+
+        # Apply fine-tuning strategy if provided
+        if is_finetune:
+            assert self.hypers["finetune"]["read_from"] is not None  # for mypy
+            model = apply_finetuning_strategy(model, self.hypers["finetune"])
+            method = self.hypers["finetune"]["method"]
+            num_params = sum(p.numel() for p in model.parameters())
+            num_trainable_params = sum(
+                p.numel() for p in model.parameters() if p.requires_grad
+            )
+
+            logging.info(f"Applied finetuning strategy: {method}")
+            logging.info(
+                f"Number of trainable parameters: {num_trainable_params} "
+                f"[{num_trainable_params / num_params:.2%}]"
+            )
+            inherit_heads = self.hypers["finetune"]["inherit_heads"]
+            if inherit_heads:
+                logging.info(
+                    "Inheriting initial weights for heads and last layers for targets: "
+                    f"from {list(inherit_heads.values())} to "
+                    f"{list(inherit_heads.keys())}"
+                )
+
+        # Move the model to the device and dtype:
+        model.to(device=device, dtype=dtype)
+        # The additive models of PET are always in float64 (to avoid numerical errors in
+        # the composition weights, which can be very large).
+ for additive_model in model.additive_models: + additive_model.to(dtype=torch.float64) + model.scaler.to(dtype=torch.float64) + + logging.info("Calculating composition weights") + model.additive_models[0].train_model( # this is the composition model + train_datasets, + model.additive_models[1:], + self.hypers["batch_size"], + is_distributed, + self.hypers["atomic_baseline"], + ) + + if self.hypers["scale_targets"]: + logging.info("Calculating scaling weights") + model.scaler.train_model( + train_datasets, + model.additive_models, + self.hypers["batch_size"], + is_distributed, + self.hypers["fixed_scaling_weights"], + ) + + logging.info("Setting up data loaders") + + if is_distributed: + train_samplers = [ + DistributedSampler( + train_dataset, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + ) + for train_dataset in train_datasets + ] + val_samplers = [ + DistributedSampler( + val_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + for val_dataset in val_datasets + ] + else: + train_samplers = [None] * len(train_datasets) + val_samplers = [None] * len(val_datasets) + + # Extract additive models and scaler and move them to CPU/float64 so they + # can be used in the collate function + model.additive_models[0].weights_to(device="cpu", dtype=torch.float64) + additive_models = copy.deepcopy( + model.additive_models.to(dtype=torch.float64, device="cpu") + ) + model.additive_models.to(device) + model.additive_models[0].weights_to(device=device, dtype=torch.float64) + model.scaler.scales_to(device="cpu", dtype=torch.float64) + scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu")) + model.scaler.to(device) + model.scaler.scales_to(device=device, dtype=torch.float64) + + # Create collate functions: + dataset_info = model.dataset_info + train_targets = dataset_info.targets + extra_data_info = dataset_info.extra_data + rotational_augmenter = RotationalAugmenter( + target_info_dict=train_targets, extra_data_info_dict=extra_data_info + ) + requested_neighbor_lists = get_requested_neighbor_lists(model) + collate_fn_train = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ + rotational_augmenter.apply_random_augmentations, + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + collate_fn_val = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ # no augmentation for validation + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + + # Create dataloader for the training datasets: + if self.hypers["num_workers"] is None: + num_workers = get_num_workers() + logging.info( + "Number of workers for data-loading not provided and chosen " + f"automatically. Using {num_workers} workers." + ) + else: + num_workers = self.hypers["num_workers"] + validate_num_workers(num_workers) + + train_dataloaders = [] + for train_dataset, train_sampler in zip( + train_datasets, train_samplers, strict=True + ): + if len(train_dataset) < self.hypers["batch_size"]: + raise ValueError( + f"A training dataset has fewer samples " + f"({len(train_dataset)}) than the batch size " + f"({self.hypers['batch_size']}). " + "Please reduce the batch size." 
+ ) + train_dataloaders.append( + DataLoader( + dataset=train_dataset, + batch_size=self.hypers["batch_size"], + sampler=train_sampler, + shuffle=( + # the sampler takes care of this (if present) + train_sampler is None + ), + drop_last=( + # the sampler takes care of this (if present) + train_sampler is None + ), + collate_fn=collate_fn_train, + num_workers=num_workers, + ) + ) + train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True) + + # Create dataloader for the validation datasets: + val_dataloaders = [] + for val_dataset, val_sampler in zip(val_datasets, val_samplers, strict=True): + val_dataloaders.append( + DataLoader( + dataset=val_dataset, + batch_size=self.hypers["batch_size"], + sampler=val_sampler, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_val, + num_workers=num_workers, + ) + ) + val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False) + + if is_distributed: + model = DistributedDataParallel(model, device_ids=[device]) + + outputs_list = [] + for target_name, target_info in train_targets.items(): + outputs_list.append(target_name) + for gradient_name in target_info.gradients: + outputs_list.append(f"{target_name}_{gradient_name}_gradients") + + # Create a loss function: + loss_hypers = cast(Dict[str, LossSpecification], self.hypers["loss"]) # mypy + loss_fn = LossAggregator(targets=train_targets, config=loss_hypers) + logging.info("Using the following loss functions:") + for name, info in loss_fn.metadata.items(): + logging.info(f"{name}:") + main = {k: v for k, v in info.items() if k != "gradients"} + logging.info(main) + if "gradients" not in info or len(info["gradients"]) == 0: + continue + logging.info("With gradients:") + for grad, ginfo in info["gradients"].items(): + logging.info(f"\t{name}::{grad}: {ginfo}") + + if self.hypers["weight_decay"] is not None: + optimizer = torch.optim.AdamW( + model.parameters(), + lr=self.hypers["learning_rate"], + weight_decay=self.hypers["weight_decay"], + ) + else: + optimizer = torch.optim.Adam( + model.parameters(), lr=self.hypers["learning_rate"] + ) + + if self.optimizer_state_dict is not None and not is_finetune: + # try to load the optimizer state dict, but this is only possible + # if there are no new targets in the model (new parameters) + if not (model.module if is_distributed else model).has_new_targets: + optimizer.load_state_dict(self.optimizer_state_dict) + + # Create a learning rate scheduler + lr_scheduler = get_scheduler(optimizer, self.hypers, len(train_dataloader)) + + if self.scheduler_state_dict is not None and not is_finetune: + # same as the optimizer, try to load the scheduler state dict + if not (model.module if is_distributed else model).has_new_targets: + lr_scheduler.load_state_dict(self.scheduler_state_dict) + + per_structure_targets = self.hypers["per_structure_targets"] + + # Log the initial learning rate: + logging.info(f"Base learning rate: {self.hypers['learning_rate']}") + + start_epoch = -1 if self.epoch is None else self.epoch + epoch = start_epoch + + # Save the untrained model checkpoint: + if self.best_metric is None: + self.best_metric = float("inf") + logging.info("Saving untrained model checkpoint") + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = 
lr_scheduler.state_dict() + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_untrained.ckpt", + ) + + # Train the model: + logging.info("Starting training") + start_epoch = start_epoch + 1 + + for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]): + if is_distributed: + for train_sampler in train_samplers: + train_sampler.set_epoch(epoch) + train_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + val_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + if self.hypers["log_mae"]: + train_mae_calculator = MAEAccumulator( + self.hypers["log_separate_blocks"] + ) + val_mae_calculator = MAEAccumulator(self.hypers["log_separate_blocks"]) + + train_loss = 0.0 + for batch in train_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + optimizer.zero_grad() + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=True, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + train_loss_batch = loss_fn(predictions, targets, extra_data) + + if is_distributed: + # make sure all parameters contribute to the gradient calculation + # to make torch DDP happy + for param in model.parameters(): + train_loss_batch += 0.0 * param.sum() + + train_loss_batch.backward() + torch.nn.utils.clip_grad_norm_( + model.parameters(), self.hypers["grad_clip_norm"] + ) + optimizer.step() + lr_scheduler.step() + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(train_loss_batch) + train_loss += train_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + train_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + train_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_train_info = train_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_train_info.update( + train_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + val_loss = 0.0 + for batch in val_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=False, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + val_loss_batch = loss_fn(predictions, targets, extra_data) + + if 
is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(val_loss_batch) + val_loss += val_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + val_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + val_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_val_info = val_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_val_info.update( + val_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + # Now we log the information: + finalized_train_info = { + "loss": train_loss, + **finalized_train_info, + } + finalized_val_info = { + "loss": val_loss, + **finalized_val_info, + } + + if epoch == start_epoch: + metric_logger = MetricLogger( + log_obj=ROOT_LOGGER, + dataset_info=( + model.module if is_distributed else model + ).dataset_info, + initial_metrics=[finalized_train_info, finalized_val_info], + names=["training", "validation"], + ) + if epoch % self.hypers["log_interval"] == 0: + metric_logger.log( + metrics=[finalized_train_info, finalized_val_info], + epoch=epoch, + rank=rank, + learning_rate=optimizer.param_groups[0]["lr"], + ) + + val_metric = get_selected_metric( + finalized_val_info, self.hypers["best_model_metric"] + ) + if val_metric < self.best_metric: + self.best_metric = val_metric + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + if epoch % self.hypers["checkpoint_interval"] == 0: + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + self.epoch = epoch + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_{epoch}.ckpt", + ) + + # prepare for the checkpoint that will be saved outside the function + self.epoch = epoch + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + + if is_distributed: + torch.distributed.destroy_process_group() + + def save_checkpoint(self, model: ModelInterface, path: Union[str, Path]) -> None: + checkpoint = model.get_checkpoint() + if self.best_model_state_dict is not None: + self.best_model_state_dict["finetune_config"] = model.finetune_config + checkpoint.update( + { + "trainer_ckpt_version": self.__checkpoint_version__, + "train_hypers": self.hypers, + "epoch": self.epoch, + "optimizer_state_dict": self.optimizer_state_dict, + "scheduler_state_dict": self.scheduler_state_dict, + "best_epoch": self.best_epoch, + "best_metric": self.best_metric, + "best_model_state_dict": self.best_model_state_dict, + "best_optimizer_state_dict": self.best_optimizer_state_dict, + } + ) + torch.save( + checkpoint, + check_file_extension(path, ".ckpt"), + ) + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + hypers: TrainerHypers, + context: Literal["restart", "finetune"], + ) -> "Trainer": + trainer = cls(hypers) + trainer.optimizer_state_dict = 
checkpoint["optimizer_state_dict"] + trainer.scheduler_state_dict = checkpoint["scheduler_state_dict"] + trainer.epoch = checkpoint["epoch"] + trainer.best_epoch = checkpoint["best_epoch"] + trainer.best_metric = checkpoint["best_metric"] + trainer.best_model_state_dict = checkpoint["best_model_state_dict"] + trainer.best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"] + + return trainer + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["trainer_ckpt_version"] == v: + update = getattr(checkpoints, f"trainer_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["trainer_ckpt_version"] = v + 1 + + if checkpoint["trainer_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using " + f"trainer version {checkpoint['trainer_ckpt_version']}, while the " + f"current trainer version is {cls.__checkpoint_version__}." + ) + return checkpoint diff --git a/src/metatrain/pet_3/__init__.py b/src/metatrain/pet_3/__init__.py new file mode 100644 index 0000000000..01db7d8aa9 --- /dev/null +++ b/src/metatrain/pet_3/__init__.py @@ -0,0 +1,21 @@ +from .model import PET +from .trainer import Trainer + + +__model__ = PET +__trainer__ = Trainer +__capabilities__ = { + "supported_devices": __model__.__supported_devices__, + "supported_dtypes": __model__.__supported_dtypes__, +} + +__authors__ = [ + ("Sergey Pozdnyakov ", "@spozdn"), + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] + +__maintainers__ = [ + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] diff --git a/src/metatrain/pet_3/checkpoints.py b/src/metatrain/pet_3/checkpoints.py new file mode 100644 index 0000000000..020b1e6590 --- /dev/null +++ b/src/metatrain/pet_3/checkpoints.py @@ -0,0 +1,422 @@ +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + + +########################### +# MODEL ################### +########################### + + +def model_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "additive_models.0.model.type_to_index" not in state_dict: + state_dict["additive_models.0.model.type_to_index"] = state_dict.pop( + "additive_models.0.type_to_index" + ) + + +def model_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "train_hypers" in checkpoint: + finetune_config = checkpoint["train_hypers"].get("finetune", {}) + else: + finetune_config = {} + state_dict["finetune_config"] = finetune_config + + +def model_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["epoch"] = checkpoint.get("epoch") + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + if checkpoint["best_model_state_dict"] is not None: + checkpoint["best_model_state_dict"] = checkpoint.get("best_model_state_dict") + + +def model_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 checkpoint to v5. + + :param checkpoint: The checkpoint to update. 
+ """ + checkpoint["model_data"]["dataset_info"]._atomic_types = list( + checkpoint["model_data"]["dataset_info"]._atomic_types + ) + + +def model_update_v5_v6(checkpoint: dict) -> None: + """ + Update a v5 checkpoint to v6. + + :param checkpoint: The checkpoint to update. + """ + if not checkpoint["best_model_state_dict"]: + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + +def model_update_v6_v7(checkpoint: dict) -> None: + """ + Update model checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to be updated. + """ + # this update consists in changes in the scaler + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if ( + "scaler.scales" not in state_dict + and "scaler.dummy_buffer" in state_dict + and "scaler.model.type_to_index" in state_dict + ): + continue # already updated + old_scales_tensor = state_dict.pop("scaler.scales") + old_output_name_to_output_index = {} + for target_index, target_name in enumerate( + checkpoint["model_data"]["dataset_info"].targets.keys() + ): + old_output_name_to_output_index[target_name] = target_index + state_dict["scaler.dummy_buffer"] = torch.tensor( + [0.0], dtype=old_scales_tensor.dtype + ) + state_dict["scaler.model.type_to_index"] = state_dict[ + "additive_models.0.model.type_to_index" + ] + for target_name, target_info in checkpoint["model_data"][ + "dataset_info" + ].targets.items(): + layout = target_info.layout + if layout.sample_names == ["system"]: + samples = Labels(["atomic_type"], torch.tensor([[-1]])) + + elif layout.sample_names == ["system", "atom"]: + samples = Labels( + ["atomic_type"], + torch.arange( + len(checkpoint["model_data"]["dataset_info"].atomic_types) + ).reshape(-1, 1), + ) + else: + raise ValueError( # will never happen + "Unknown sample kind. Please contact the developers." + ) + scales_tensormap = TensorMap( + keys=layout.keys, + blocks=[ + TensorBlock( + values=torch.full( # important when scale_targets=False + (len(samples), len(block.properties)), + old_scales_tensor[ + old_output_name_to_output_index[target_name] + ], + dtype=torch.float64, + ), + samples=samples, + components=[], + properties=block.properties, + ) + for block in layout.blocks() + ], + ) + state_dict[f"scaler.{target_name}_scaler_buffer"] = mts.save_buffer( + mts.make_contiguous(scales_tensormap) + ) + + +def model_update_v7_v8(checkpoint: dict) -> None: + """ + Update a v7 checkpoint to v8. + + :param checkpoint: The checkpoint to update. + """ + if "node_embedding.weight" in checkpoint["model_state_dict"]: + ############################################## + # **Updating the large-scale PET checkpoints** + ############################################## + + checkpoint["model_data"]["model_hypers"]["normalization"] = "RMSNorm" + checkpoint["model_data"]["model_hypers"]["activation"] = "SwiGLU" + checkpoint["model_data"]["model_hypers"]["d_node"] = ( + 4 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PreLN" + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "feedforward" + checkpoint["model_data"]["model_hypers"]["d_feedforward"] = ( + 2 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + if (state_dict := checkpoint.get("model_state_dict")) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + if "embedding." 
in k and "node" not in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if "node_embedding." in k: + k = k.replace("node_embedding.", "node_embedders.0.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + if "combination_rmsnorms" in k: + k = k.replace("combination_rmsnorms", "combination_norms") + new_state_dict[k] = v + checkpoint["model_state_dict"] = new_state_dict + + # for the large-scale checkpoints, the model for evaluation is always + # taken to be the last + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + else: + ############################################### + # **Updating the standard old PET checkpoints** + ############################################### + + # Adding the option for choosing the normalization type + if "normalization" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["normalization"] = "LayerNorm" + # Adding the option for choosing the activation function + if "activation" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["activation"] = "SiLU" + # Setting the node features dimension to be the same as d_pet if not specified + if "d_node" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["d_node"] = checkpoint[ + "model_data" + ]["model_hypers"]["d_pet"] + # Setting the default transformer type to PostLN if not specified + if "transformer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PostLN" + if "featurizer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "residual" + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + # Replacing the nn.Sequential MLP with a custom FeedForward module + if ".mlp.0" in k: + k = k.replace(".mlp.0", ".mlp.w_in") + if ".mlp.3" in k: + k = k.replace(".mlp.3", ".mlp.w_out") + # Moving the node embedder to a top-level PET model attribute + if "embedding." in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".node_embedder." in k: + key_content = k.split(".") + k = ".".join( + ["node_embedders", key_content[1], key_content[-1]] + ) + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + new_state_dict[k] = v + checkpoint[key] = new_state_dict + + +def model_update_v8_v9(checkpoint: dict) -> None: + """ + Update a v8 checkpoint to v9. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "finetune_config" in state_dict: + if "inherit_heads" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["inherit_heads"] = {} + if "method" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["method"] = "full" + + +def model_update_v9_v10(checkpoint: dict) -> None: + """ + Update a v9 checkpoint to v10. + + :param checkpoint: The checkpoint to update. 
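+
+    Illustrative sketch of the effect (only the keys touched by this update are
+    shown)::
+
+        hypers = checkpoint["model_data"]["model_hypers"]
+        # after the update, these keys exist even if they were previously missing:
+        # hypers["num_neighbors_adaptive"]  (defaults to None)
+        # hypers["cutoff_function"]         (defaults to "Cosine")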
+ """ + # Adding the num_neighbors_adaptive hyperparameter if not present + if "num_neighbors_adaptive" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["num_neighbors_adaptive"] = None + if "cutoff_function" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["cutoff_function"] = "Cosine" + + +def model_update_v10_v11(checkpoint: dict) -> None: + """ + Update a v10 checkpoint to v11. + + :param checkpoint: The checkpoint to update. + """ + # Adding the attention_temperature hyperparameter if not present + if "attention_temperature" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["attention_temperature"] = 1.0 + + +########################### +# TRAINER ################# +########################### + + +def trainer_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 Trainer checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["scheduler_factor"] = checkpoint["train_hypers"].get( + "scheduler_factor", 0.5 + ) + + +def trainer_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 Trainer checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + +def trainer_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 Trainer checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + new_loss_hypers[target_name] = { + "type": old_loss_hypers["type"], + "weight": old_loss_hypers["weights"].get(target_name, 1.0), + "reduction": old_loss_hypers["reduction"], + "sliding_factor": old_loss_hypers.get("sliding_factor", None), + } + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 Trainer checkpoint to v5. + + :param checkpoint: The checkpoint to update. + """ + raise ValueError( + "In order to use this checkpoint, you need metatrain 2025.10 or earlier. " + "You can install it with `pip install metatrain==2025.10`." + ) + + +def trainer_update_v5_v6(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 5 to version 6. + + :param checkpoint: The checkpoint to update. + """ + # num_workers=0 means that the main process will do the data loading, which is + # equivalent to not setting it (this was the behavior before v6) + checkpoint["train_hypers"]["num_workers"] = 0 + + +def trainer_update_v6_v7(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["fixed_scaling_weights"] = {} + + +def trainer_update_v7_v8(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. 
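+
+    Illustrative sketch of the effect on ``checkpoint["train_hypers"]["loss"]``
+    (the target name and values are made up)::
+
+        # before: {"energy": {"type": "mse", "weight": 1.0,
+        #                     "reduction": "mean", "sliding_factor": 0.01}}
+        # after:  {"energy": {"type": "mse", "weight": 1.0, "reduction": "mean"}}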
+ """ + # remove all entries in the loss `sliding_factor` + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + # retain everything except sliding_factor for each target + new_loss_hypers[target_name] = { + k: v + for k, v in old_loss_hypers[target_name].items() + if k != "sliding_factor" + } + + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v8_v9(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. + """ + # Adding the empty finetune config if not present + if "finetune" not in checkpoint["train_hypers"]: + checkpoint["train_hypers"]["finetune"] = { + "read_from": None, + "method": "full", + "config": {}, + "inherit_heads": {}, + } + + +def trainer_update_v9_v10(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 9 to version 10. + + :param checkpoint: The checkpoint to update. + """ + # Ensuring that the finetune read_from is None if not specified + checkpoint["train_hypers"]["remove_composition_contribution"] = True + + +def trainer_update_v10_v11(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 10 to version 11. + + :param checkpoint: The checkpoint to update. + """ + # - Remove the ``remove_composition_contribution`` hyper. + # - Rename ``fixed_composition_weights`` to ``atomic_baseline``. + # - If ``remove_composition_contribution`` is False, set all atomic baselines + # to 0.0 for all targets. + use_atomic_baseline = checkpoint["train_hypers"].pop( + "remove_composition_contribution" + ) + atomic_baseline = checkpoint["train_hypers"].pop("fixed_composition_weights") + + if not use_atomic_baseline: + # Just set + dataset_info = checkpoint["model_data"]["dataset_info"] + atomic_baseline = {target_name: 0.0 for target_name in dataset_info.targets} + + checkpoint["train_hypers"]["atomic_baseline"] = atomic_baseline + + +def trainer_update_v11_v12(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 11 to version 12. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["batch_atom_bounds"] = [None, None] diff --git a/src/metatrain/pet_3/documentation.py b/src/metatrain/pet_3/documentation.py new file mode 100644 index 0000000000..dd925eb82c --- /dev/null +++ b/src/metatrain/pet_3/documentation.py @@ -0,0 +1,251 @@ +""" +PET +=== + +PET is a cleaner, more user-friendly reimplementation of the original +PET model :footcite:p:`pozdnyakov_smooth_2023`. It is designed for better +modularity and maintainability, while preseving compatibility with the original +PET implementation in ``metatrain``. It also adds new features like long-range +features, better fine-tuning implementation, a possibility to train on +arbitrarty targets, and a faster inference due to the ``fast attention``. + +{{SECTION_INSTALLATION}} + +{{SECTION_DEFAULT_HYPERS}} + +Tuning hyperparameters +---------------------- + +The default hyperparameters above will work well in most cases, but they +may not be optimal for your specific dataset. There is good number of +parameters to tune, both for the +:ref:`model ` and the +:ref:`trainer `. Since seeing them +for the first time might be overwhelming, here we provide a **list of the +parameters that are in general the most important** (in decreasing order +of importance): + +.. container:: mtt-hypers-remove-classname + + .. 
autoattribute:: {{model_hypers_path}}.cutoff + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_neighbors_adaptive + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.learning_rate + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.batch_size + :no-index: + + .. autoattribute:: {{model_hypers_path}}.d_pet + :no-index: + + .. autoattribute:: {{model_hypers_path}}.d_node + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_gnn_layers + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_attention_layers + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.loss + :no-index: + + .. autoattribute:: {{model_hypers_path}}.long_range + :no-index: +""" + +from typing import Literal, Optional + +from typing_extensions import TypedDict + +from metatrain.pet_3.modules.finetuning import FinetuneHypers, NoFinetuneHypers +from metatrain.utils.additive import FixedCompositionWeights +from metatrain.utils.hypers import init_with_defaults +from metatrain.utils.long_range import LongRangeHypers +from metatrain.utils.loss import LossSpecification +from metatrain.utils.scaler import FixedScalerWeights + + +class ModelHypers(TypedDict): + """Hyperparameters for the PET model.""" + + cutoff: float = 4.5 + """Cutoff radius for neighbor search. + + This should be set to a value after which most of the interactions + between atoms is expected to be negligible. A lower cutoff will lead + to faster models. + """ + num_neighbors_adaptive: Optional[int] = None + """Target number of neighbors for the adaptive cutoff scheme. + + This parameter activates the adaptive cutoff functionality. + Each atomic environments has a different cutoff, that is chosen + such that the number of neighbors is approximately equal to this + value. This can be useful to have a more uniform number of neighbors + per atom, especially in sparse systems. Setting it to None disables + this feature and uses all neighbors within the fixed cutoff radius. + """ + cutoff_function: Literal["Cosine", "Bump"] = "Bump" + """Type of the smoothing function at the cutoff""" + cutoff_width: float = 0.5 + """Width of the smoothing function at the cutoff""" + d_pet: int = 128 + """Dimension of the edge features. + + This hyperparameters controls width of the neural network. In general, + increasing it might lead to better accuracy, especially on larger datasets, at the + cost of increased training and evaluation time. + """ + d_head: int = 128 + """Dimension of the attention heads.""" + d_node: int = 256 + """Dimension of the node features. + + Increasing this hyperparameter might lead to better accuracy, + with a relatively small increase in inference time. + """ + d_feedforward: int = 256 + """Dimension of the feedforward network in the attention layer.""" + num_heads: int = 8 + """Attention heads per attention layer.""" + num_attention_layers: int = 2 + """The number of attention layers in each layer of the graph + neural network. Depending on the dataset, increasing this hyperparameter might + lead to better accuracy, at the cost of increased training and evaluation time. + """ + num_gnn_layers: int = 2 + """The number of graph neural network layers. + + In general, decreasing this hyperparameter to 1 will lead to much faster models, + at the expense of accuracy. Increasing it may or may not lead to better accuracy, + depending on the dataset, at the cost of increased training and evaluation time. 
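+
+    Example (a minimal sketch of overriding this hyperparameter in code;
+    ``init_with_defaults`` is imported in this module, and the override
+    values are purely illustrative)::
+
+        hypers = init_with_defaults(ModelHypers)
+        hypers["num_gnn_layers"] = 1  # single message-passing hop: fastest
+        hypers["cutoff"] = 5.0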
+ """ + normalization: Literal["RMSNorm", "LayerNorm"] = "RMSNorm" + """Layer normalization type.""" + activation: Literal["SiLU", "SwiGLU"] = "SwiGLU" + """Activation function.""" + attention_temperature: float = 1.0 + """The temperature scaling factor for attention scores.""" + transformer_type: Literal["PreLN", "PostLN"] = "PreLN" + """The order in which the layer normalization and attention + are applied in a transformer block. Available options are ``PreLN`` + (normalization before attention) and ``PostLN`` (normalization after attention).""" + featurizer_type: Literal["residual", "feedforward"] = "feedforward" + """Implementation of the featurizer of the model to use. Available + options are ``residual`` (the original featurizer from the PET paper, that uses + residual connections at each GNN layer for readout) and ``feedforward`` (a modern + version that uses the last representation after all GNN iterations for readout). + Additionally, the feedforward version uses bidirectional features flow during the + message passing iterations, that favors features flowing from atom ``i`` to atom + ``j`` to be not equal to the features flowing from atom ``j`` to atom ``i``.""" + zbl: bool = False + """Use ZBL potential for short-range repulsion""" + long_range: LongRangeHypers = init_with_defaults(LongRangeHypers) + """Long-range Coulomb interactions parameters.""" + + +class TrainerHypers(TypedDict): + """Hyperparameters for training PET models.""" + + distributed: bool = False + """Whether to use distributed training""" + distributed_port: int = 39591 + """Port for distributed communication among processes""" + batch_size: int = 16 + """The number of samples to use in each batch of training. This + hyperparameter controls the tradeoff between training speed and memory usage. In + general, larger batch sizes will lead to faster training, but might require more + memory.""" + num_epochs: int = 1000 + """Number of epochs.""" + warmup_fraction: float = 0.01 + """Fraction of training steps used for learning rate warmup.""" + learning_rate: float = 1e-4 + """Learning rate.""" + weight_decay: Optional[float] = None + + log_interval: int = 1 + """Interval to log metrics.""" + checkpoint_interval: int = 100 + """Interval to save checkpoints.""" + atomic_baseline: FixedCompositionWeights = {} + """The baselines for each target. + + By default, ``metatrain`` will fit a linear model (:class:`CompositionModel + `) to compute the + least squares baseline for each atomic species for each target. + + However, this hyperparameter allows you to provide your own baselines. + The value of the hyperparameter should be a dictionary where the keys are the + target names, and the values are either (1) a single baseline to be used for + all atomic types, or (2) a dictionary mapping atomic types to their baselines. + For example: + + - ``atomic_baseline: {"energy": {1: -0.5, 6: -10.0}}`` will fix the energy + baseline for hydrogen (Z=1) to -0.5 and for carbon (Z=6) to -10.0, while + fitting the baselines for the energy of all other atomic types, as well + as fitting the baselines for all other targets. + - ``atomic_baseline: {"energy": -5.0}`` will fix the energy baseline for + all atomic types to -5.0. + - ``atomic_baseline: {"mtt:dos": 0.0}`` sets the baseline for the "mtt:dos" + target to 0.0, effectively disabling the atomic baseline for that target. 
+ + This atomic baseline is substracted from the targets during training, which + avoids the main model needing to learn atomic contributions, and likely makes + training easier. When the model is used in evaluation mode, the atomic baseline + is added on top of the model predictions automatically. + + .. note:: + + This atomic baseline is a per-atom contribution. Therefore, if the property + you are predicting is a sum over all atoms (e.g., total energy), the + contribution of the atomic baseline to the total property will be the + atomic baseline multiplied by the number of atoms of that type in the + structure. + """ + scale_targets: bool = True + """Normalize targets to unit std during training.""" + fixed_scaling_weights: FixedScalerWeights = {} + """Weights for target scaling. + + This is passed to the ``fixed_weights`` argument of + :meth:`Scaler.train_model `, + see its documentation to understand exactly what to pass here. + """ + per_structure_targets: list[str] = [] + """Targets to calculate per-structure losses.""" + num_workers: Optional[int] = None + """Number of workers for data loading. If not provided, it is set + automatically.""" + log_mae: bool = True + """Log MAE alongside RMSE""" + log_separate_blocks: bool = False + """Log per-block error.""" + best_model_metric: Literal["rmse_prod", "mae_prod", "loss"] = "mae_prod" + """Metric used to select best checkpoint (e.g., ``rmse_prod``)""" + grad_clip_norm: float = 1.0 + """Maximum gradient norm value.""" + loss: str | dict[str, LossSpecification | str] = "mse" + """This section describes the loss function to be used. See the + :ref:`loss-functions` for more details.""" + batch_atom_bounds: list[Optional[int]] = [None, None] + """Bounds for the number of atoms per batch as [min, max]. Batches with atom + counts outside these bounds will be skipped during training. Use ``None`` for + either value to disable that bound. This is useful for preventing out-of-memory + errors and ensuring consistent computational load. Default: ``[None, None]``.""" + + finetune: NoFinetuneHypers | FinetuneHypers = { + "read_from": None, + "method": "full", + "config": {}, + "inherit_heads": {}, + } + """Parameters for fine-tuning trained PET models. + + See :ref:`label_fine_tuning_concept` for more details. + """ diff --git a/src/metatrain/pet_3/model.py b/src/metatrain/pet_3/model.py new file mode 100644 index 0000000000..27805597d5 --- /dev/null +++ b/src/metatrain/pet_3/model.py @@ -0,0 +1,1658 @@ +import logging +import typing +import warnings +from math import prod +from typing import Any, Dict, List, Literal, Optional, Tuple + +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.operations._add import _add_block_block +from metatomic.torch import ( + AtomisticModel, + ModelCapabilities, + ModelMetadata, + ModelOutput, + NeighborListOptions, + System, +) +from torch.utils.hooks import RemovableHandle + +from metatrain.utils.abc import ModelInterface +from metatrain.utils.additive import ZBL, CompositionModel +from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.dtype import dtype_to_str +from metatrain.utils.long_range import DummyLongRangeFeaturizer, LongRangeFeaturizer +from metatrain.utils.metadata import merge_metadata +from metatrain.utils.scaler import Scaler +from metatrain.utils.sum_over_atoms import sum_over_atoms + +from . 
import checkpoints +from .documentation import ModelHypers +from .modules.finetuning import apply_finetuning_strategy +from .modules.readout import ReadoutLayer +from .modules.structures import get_pair_sample_labels, systems_to_batch +from .modules.transformer import CartesianTransformer + + +AVAILABLE_FEATURIZERS = typing.get_args(ModelHypers.__annotations__["featurizer_type"]) + + +class PET(ModelInterface[ModelHypers]): + """ + Metatrain-native implementation of the PET architecture. + + Originally proposed in work (https://arxiv.org/abs/2305.19302v3), + and published in the `pet` package (https://github.com/spozdn/pet). + + :param hypers: Hyperparameters for the PET model. See the documentation for details. + :param dataset_info: Information about the dataset, including atomic types and + targets. + """ + + __checkpoint_version__ = 11 + __supported_devices__ = ["cuda", "cpu"] + __supported_dtypes__ = [torch.float32, torch.float64] + __default_metadata__ = ModelMetadata( + references={"architecture": ["https://arxiv.org/abs/2305.19302v3"]} + ) + component_labels: Dict[str, List[List[Labels]]] + NUM_FEATURE_TYPES: int = 2 # node + edge features + + def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None: + super().__init__(hypers, dataset_info, self.__default_metadata__) + + # Cache frequently accessed hyperparameters + self.cutoff = float(self.hypers["cutoff"]) + self.cutoff_function = self.hypers["cutoff_function"] + self.cutoff_width = float(self.hypers["cutoff_width"]) + self.num_neighbors_adaptive = ( + float(self.hypers["num_neighbors_adaptive"]) + if self.hypers["num_neighbors_adaptive"] is not None + else None + ) + self.d_pet = self.hypers["d_pet"] + self.d_node = self.hypers["d_node"] + self.d_head = self.hypers["d_head"] + self.d_feedforward = self.hypers["d_feedforward"] + self.num_heads = self.hypers["num_heads"] + self.num_gnn_layers = self.hypers["num_gnn_layers"] + self.num_attention_layers = self.hypers["num_attention_layers"] + self.normalization = self.hypers["normalization"] + self.activation = self.hypers["activation"] + self.attention_temperature = self.hypers["attention_temperature"] + self.transformer_type = self.hypers["transformer_type"] + self.featurizer_type = self.hypers["featurizer_type"] + + self.atomic_types = dataset_info.atomic_types + self.requested_nl = NeighborListOptions( + cutoff=self.cutoff, + full_list=True, + strict=True, + ) + num_atomic_species = len(self.atomic_types) + self.gnn_layers = torch.nn.ModuleList( + [ + CartesianTransformer( + self.cutoff, + self.cutoff_width, + self.d_pet, + self.num_heads, + self.d_node, + self.d_feedforward, + self.num_attention_layers, + self.normalization, + self.activation, + self.attention_temperature, + self.transformer_type, + num_atomic_species, + layer_index == 0, # is first layer + ) + for layer_index in range(self.num_gnn_layers) + ] + ) + if self.featurizer_type not in AVAILABLE_FEATURIZERS: + raise ValueError( + f"Unknown featurizer type: {self.featurizer_type}. 
" + f"Available options are: {AVAILABLE_FEATURIZERS}" + ) + if self.featurizer_type == "feedforward": + self.num_readout_layers = 1 + self.combination_norms = torch.nn.ModuleList( + [torch.nn.LayerNorm(2 * self.d_pet) for _ in range(self.num_gnn_layers)] + ) + self.combination_mlps = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(2 * self.d_pet, 2 * self.d_pet), + torch.nn.SiLU(), + torch.nn.Linear(2 * self.d_pet, self.d_pet), + ) + for _ in range(self.num_gnn_layers) + ] + ) + else: + self.num_readout_layers = self.num_gnn_layers + self.combination_norms = torch.nn.ModuleList() + self.combination_mlps = torch.nn.ModuleList() + + self.node_embedders = torch.nn.ModuleList( + [ + torch.nn.Embedding(num_atomic_species, self.d_node) + for _ in range(self.num_readout_layers) + ] + ) + self.edge_embedder = torch.nn.Embedding(num_atomic_species, self.d_pet) + + self.node_heads = torch.nn.ModuleDict() + self.edge_heads = torch.nn.ModuleDict() + self.node_last_layer_shifts = torch.nn.ModuleDict() # not needed + self.edge_last_layer_shifts = torch.nn.ModuleDict() # not needed + self.node_last_layers = torch.nn.ModuleDict() + self.edge_last_layers = torch.nn.ModuleDict() + self.last_layer_feature_size = ( + self.num_readout_layers * self.d_head * self.NUM_FEATURE_TYPES + ) # for LLPR + + # the model is always capable of outputting the internal features + self.outputs = { + "features": ModelOutput(per_atom=True, description="internal features") + } + + self.output_shapes: Dict[str, Dict[str, List[int]]] = {} + self.key_labels: Dict[str, Labels] = {} + self.property_labels: Dict[str, List[Labels]] = {} + self.component_labels: Dict[str, List[List[Labels]]] = {} + self.target_names: List[str] = [] + self.last_layer_parameter_names: Dict[str, List[str]] = {} # for LLPR + for target_name, target_info in dataset_info.targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target_info) + + self.register_buffer( + "species_to_species_index", + torch.full((max(self.atomic_types) + 1,), -1), + ) + for i, species in enumerate(self.atomic_types): + self.species_to_species_index[species] = i + + # long-range module + if self.hypers["long_range"]["enable"]: + self.long_range = True + if not self.hypers["long_range"]["use_ewald"]: + warnings.warn( + "Training PET with the LongRangeFeaturizer initialized " + "with `use_ewald=False` causes instabilities during training. " + "The `use_ewald` variable will be force-switched to `True`. 
" + "during training.", + UserWarning, + stacklevel=2, + ) + self.long_range_featurizer = LongRangeFeaturizer( + hypers=self.hypers["long_range"], + feature_dim=self.d_node, + neighbor_list_options=self.requested_nl, + ) + else: + self.long_range = False + self.long_range_featurizer = DummyLongRangeFeaturizer() # for torchscript + + # additive models: these are handled by the trainer at training + # time, and they are added to the output at evaluation time + composition_model = CompositionModel( + hypers={}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + additive_models = [composition_model] + + # Adds the ZBL repulsion model if requested + if self.hypers["zbl"]: + additive_models.append( + ZBL( + {}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if ZBL.is_valid_target(target_name, target_info) + }, + ), + ) + ) + self.additive_models = torch.nn.ModuleList(additive_models) + + # scaler: this is also handled by the trainer at training time + self.scaler = Scaler(hypers={}, dataset_info=dataset_info) + + self.single_label = Labels.single() + + self.finetune_config: Dict[str, Any] = {} + + def supported_outputs(self) -> Dict[str, ModelOutput]: + return self.outputs + + def restart(self, dataset_info: DatasetInfo) -> "PET": + # merge old and new dataset info + merged_info = self.dataset_info.union(dataset_info) + new_atomic_types = [ + at for at in merged_info.atomic_types if at not in self.atomic_types + ] + new_targets = { + key: value + for key, value in merged_info.targets.items() + if key not in self.dataset_info.targets + } + self.has_new_targets = len(new_targets) > 0 + + if len(new_atomic_types) > 0: + raise ValueError( + f"New atomic types found in the dataset: {new_atomic_types}. " + "The PET model does not support adding new atomic types." + ) + + # register new outputs as new last layers + for target_name, target in new_targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target) + + self.dataset_info = merged_info + + # restart the composition and scaler models + self.additive_models[0] = self.additive_models[0].restart( + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + self.scaler = self.scaler.restart(dataset_info) + + return self + + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return [self.requested_nl] + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + """ + Forward pass of the PET model. + + The forward pass processes atomic systems through multiple stages to produce + predictions for the requested outputs. The computation follows a graph neural + network architecture with attention-based message passing. 
+ + **Stage 0: Input Preparation** + + The input systems are first converted into a batched representation containing: + + - `element_indices_nodes` [n_atoms]: Atomic species of the central atoms + - `element_indices_neighbors` [n_atoms, max_num_neighbors]: Atomic species of + neighboring atoms + - `edge_vectors` [n_atoms, max_num_neighbors, 3]: Cartesian edge vectors + between central atoms and their neighbors + - `padding_mask` [n_atoms, max_num_neighbors]: Mask indicating real vs padded + neighbors + - `reverse_neighbor_index` [n_atoms * max_num_neighbors]: Index of the ji edge + for each ij edge, once the edges are flattened into an array whose first + dimension is n_atoms * max_num_neighbors + - `system_indices` [n_atoms]: System index for each central atom + - `sample_labels` [n_atoms, 2]: Metatensor Labels containing indices of each + atom in each system + + **Stage 1: Feature Computation via GNN Layers** + + Node and edge representations are computed by iterating through the GNN layers + following one of two featurization strategies: + + - **Feedforward featurization**: Features are propagated through all + `num_gnn_layers` GNN layers sequentially, using only the final layer outputs + for readout. At each layer, forward and reversed edge messages are combined + using combination MLPs to enable bidirectional information flow. + + - **Residual featurization**: Intermediate node and edge features from each + GNN layer are saved and used during readout. Edge messages between layers are + averaged to maintain information from all hops. + + During this stage, the model: + + - Embeds atomic species into learned node and edge representations + - Applies Cartesian transformer layers to update features via attention + - Uses reversed neighbor lists to enable bidirectional message passing, where + the new input message from atom `j` to atom `i` in GNN layer N+1 is the + reversed message from atom `i` to atom `j` in GNN layer N + - Applies cutoff functions to weight interactions by distance + + If the long-range module is enabled, electrostatic features computed via Ewald + summation (during training) or Particle-Particle Particle Mesh Ewald (P3M) + (during evaluation) are added to the node features from each GNN layer. + + **Stage 2: Intermediate Feature Output (Optional)** + + If "features" is requested in the outputs, node and edge features from all + layers are concatenated to produce intermediate representations. Edge features + are summed over neighbors with cutoff weighting to obtain per-node + contributions. This output can be used for transfer learning or analysis. + + **Stage 3: Last Layer Feature Computation** + + For each requested output, output-specific heads (shallow MLPs with two linear + layers and SiLU activations) are applied to both node and edge features from + each GNN layer. This produces last layer features that are specialized for each + prediction target. These features can be optionally returned as auxiliary + outputs (e.g., "mtt::aux::energy_last_layer_features") for analysis or + transfer learning. 
+ + **Stage 4: Atomic Predictions** + + Final linear layers are applied to the last layer features to produce per-atom + predictions for each requested output: + + - Node and edge last layer features are processed through separate linear + layers for each output block + - Contributions from all GNN layers are summed + - Edge contributions are summed over neighbors with cutoff weighting + - For rank-2 Cartesian tensors (e.g., stress), predictions are symmetrized and + normalized by cell volume + - Multiple tensor blocks per output are handled independently + + **Post-processing (Evaluation Only)** + + During evaluation (not training), the following transformations are applied: + + 1. **Scaling**: Predictions are scaled using learned or configured scale + factors + 2. **Additive contributions**: Composition model and optional ZBL repulsion + contributions are added to the predictions + + :param systems: List of `metatomic.torch.System` objects to process. Each + system should contain atomic positions, species, and cell information, with + neighbor lists computed according to `requested_neighbor_lists()`. + :param outputs: Dictionary of requested outputs in the format + {output_name: ModelOutput(...)}. The model supports: + + - Target properties (energy, forces, stress, etc.) + - "features": intermediate representations from Stage 2 + - Auxiliary last layer features (e.g., + "mtt::aux::energy_last_layer_features") + + :param selected_atoms: Optional `metatensor.torch.Labels` object specifying a + subset of atoms for which to compute outputs. If `None`, all atoms are + included. This is useful for computing properties for specific atomic + environments. + :return: Dictionary of `metatensor.torch.TensorMap` objects containing the + requested outputs. Each TensorMap contains per-atom or per-structure + predictions (depending on the ModelOutput configuration) with appropriate + metatensor metadata (samples, components, properties). 
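+
+        Example (a minimal usage sketch; it assumes ``model`` is an instance
+        of this class in evaluation mode, ``systems`` already carry the
+        neighbor lists from :meth:`requested_neighbor_lists`, and ``"energy"``
+        is one of the targets the model was trained on)::
+
+            outputs = {
+                "energy": ModelOutput(per_atom=False),
+                "features": ModelOutput(per_atom=True),
+            }
+            predictions = model(systems, outputs)
+            energies = predictions["energy"].block().values  # one row per system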
+ """ + device = systems[0].device + return_dict: Dict[str, TensorMap] = {} + nl_options = self.requested_neighbor_lists()[0] + + if self.single_label.values.device != device: + self._move_labels_to_device(device) + + # **Stage 0: Input Preparation** + ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + padding_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + sample_labels, + centers, + nef_to_edges_neighbor, + ) = systems_to_batch( + systems, + nl_options, + self.atomic_types, + self.species_to_species_index, + self.cutoff_function, + self.cutoff_width, + self.num_neighbors_adaptive, + selected_atoms, + ) + + pair_sample_labels = get_pair_sample_labels( + systems, sample_labels, nl_options, device + ) + + # Optional diagnostic token capture: register temporary module hooks + diagnostic_handles = torch.jit.annotate(List[Any], []) + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + diagnostic_handles = self._prepare_diagnostic_handles( + outputs, + return_dict, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + + # the scaled_dot_product_attention function from torch cannot do + # double backward, so we will use manual attention if needed + use_manual_attention = edge_vectors.requires_grad and self.training + + with torch.profiler.record_function("PET::_calculate_features"): + # **Stage 1: Feature Computation via GNN Layers** + featurizer_inputs: Dict[str, torch.Tensor] = dict( + element_indices_nodes=element_indices_nodes, + element_indices_neighbors=element_indices_neighbors, + edge_vectors=edge_vectors, + edge_distances=edge_distances, + reverse_neighbor_index=reverse_neighbor_index, + padding_mask=padding_mask, + cutoff_factors=cutoff_factors, + ) + for featurizer_input_name, tensor in featurizer_inputs.items(): + if "mtt::features::" + featurizer_input_name in outputs: + return_dict["mtt::features::" + featurizer_input_name] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + node_features_list, edge_features_list = self._calculate_features( + featurizer_inputs, + use_manual_attention=use_manual_attention, + ) + + # If the long-range module is activated, we add the long-range features + # on top of the node features + + if self.long_range: + long_range_features = self._calculate_long_range_features( + systems, node_features_list, edge_distances, padding_mask + ) + for i in range(self.num_readout_layers): + node_features_list[i] = ( + node_features_list[i] + long_range_features + ) * 0.5**0.5 + + # **Stage 2: Intermediate Feature Output (Optional)** + with torch.profiler.record_function("PET::_get_output_features"): + if "features" in outputs: + features_dict = self._get_output_features( + node_features_list, + edge_features_list, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + # Since return_dict.update(features_dict) is not Torch-Scriptable, + # we use a simple iteration over the features_dict items. 
+ for k, v in features_dict.items(): + return_dict[k] = v + + # **Stage 3: Last Layer Feature Computation** + with torch.profiler.record_function("PET::_calculate_last_layer_features"): + node_last_layer_features_dict, edge_last_layer_features_dict = ( + self._calculate_last_layer_features( + node_features_list, + edge_features_list, + ) + ) + last_layer_features_dict = self._get_output_last_layer_features( + node_last_layer_features_dict, + edge_last_layer_features_dict, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + + for k, v in last_layer_features_dict.items(): + return_dict[k] = v + + # **Stage 4: Atomic Predictions** + with torch.profiler.record_function("PET::_calculate_atomic_predictions"): + node_atomic_predictions_dict, edge_atomic_predictions_dict = ( + self._calculate_atomic_predictions( + systems, + node_last_layer_features_dict, + edge_last_layer_features_dict, + padding_mask, + cutoff_factors, + outputs, + ) + ) + atomic_predictions_dict = self._get_output_atomic_predictions( + systems, + node_atomic_predictions_dict, + edge_atomic_predictions_dict, + edge_vectors, + system_indices, + sample_labels, + outputs, + selected_atoms, + ) + + for k, v in atomic_predictions_dict.items(): + return_dict[k] = v + + # **Post-processing (Evaluation Only)** + + with torch.profiler.record_function("PET::post-processing"): + if not self.training: + # at evaluation, we also introduce the scaler and additive contributions + return_dict = self.scaler( + systems, return_dict, selected_atoms=selected_atoms + ) + for additive_model in self.additive_models: + outputs_for_additive_model: Dict[str, ModelOutput] = {} + for name, output in outputs.items(): + if name in additive_model.outputs: + outputs_for_additive_model[name] = output + additive_contributions = additive_model( + systems, + outputs_for_additive_model, + selected_atoms, + ) + for name in additive_contributions: + # TODO: uncomment this after metatensor.torch.add + # is updated to handle sparse sums + # return_dict[name] = metatensor.torch.add( + # return_dict[name], + # additive_contributions[name].to( + # device=return_dict[name].device, + # dtype=return_dict[name].dtype + # ), + # ) + # TODO: "manual" sparse sum: update to metatensor.torch.add + # after sparse sum is implemented in metatensor.operations + output_blocks: List[TensorBlock] = [] + for k, b in return_dict[name].items(): + if k in additive_contributions[name].keys: + output_blocks.append( + _add_block_block( + b, + additive_contributions[name] + .block(k) + .to(device=b.device, dtype=b.dtype), + ) + ) + else: + output_blocks.append(b) + return_dict[name] = TensorMap( + return_dict[name].keys, output_blocks + ) + + # remove any diagnostic hooks we registered and attach tokens + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + for h in diagnostic_handles: + h.remove() + + return return_dict + + def _create_diagnostic_feature_tensormap( + self, + tensor: torch.Tensor, + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> TensorMap: + assert tensor.shape[0] == sample_labels.values.shape[0], ( + "diagnostic feature tensor must be per-atom or per-pair like in shape." 
+ f" Got tensor.shape = {tensor.shape}, " + ) + + device = tensor.device + + outp = tensor.detach().clone() + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + if outp.shape[1] == 1: # node-like, shape (n_atoms, 1, d) + outp = outp.squeeze(1) + labels = sample_labels + + else: # edge-like, shape (n_atoms, num_neighbors, d) + outp = outp[centers, nef_to_edges_neighbor] + labels = pair_sample_labels + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + return TensorMap( + Labels(["_"], torch.tensor([[0]])).to(device=device), + [ + TensorBlock( + values=outp, + samples=labels.to(device=device), + components=[], + properties=Labels( + ["_"], + torch.arange(outp.shape[1]).reshape(-1, 1), + ).to(device=device), + ) + ], + ) + + @torch.jit.ignore + def _prepare_diagnostic_handles( + self, + outputs: Dict[str, ModelOutput], + return_dict: Dict[str, Any], + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> List[Any]: + """ + Prepare forward hooks to capture diagnostic tokens from internal modules. + + :param outputs: Dictionary of requested outputs. + :param return_dict: Dictionary to store captured tokens. + :param centers: Tensor mapping center atoms to their indices. + :param nef_to_edges_neighbor: Tensor mapping neighbor edges to their indices. + :param sample_labels: Labels for individual atoms. + :param pair_sample_labels: Labels for atom pairs. + + :return: List of removable handles for the registered hooks. + """ + diagnostic_handles = [] + + def _resolve_module(path: str) -> Any: + obj: Any = self + for part in path.split("."): + if part.isdigit(): + obj = obj[int(part)] + else: + if not hasattr(obj, part): + raise AttributeError( + f"Module path '{path}' not found at '{part}'" + ) + obj = getattr(obj, part) + return obj + + # Build list of possible module paths that can be captured for these model + # hypers. 
+ possible_capture_paths: List[str] = [] + + for i in range(self.num_readout_layers): + possible_capture_paths.append(f"node_embedders.{i}") + possible_capture_paths.append("edge_embedder") + + for i in range(self.num_gnn_layers): + # Total GNN layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}_node") + possible_capture_paths.append(f"gnn_layers.{i}_edge") + + # Finer-grained: embeddings and compressions + possible_capture_paths.append(f"gnn_layers.{i}.edge_embedder") + if i > 0: + possible_capture_paths.append(f"gnn_layers.{i}.neighbor_embedder") + possible_capture_paths.append(f"gnn_layers.{i}.compress") + + # transformer layers + for j in range(self.num_attention_layers): + # Transformer layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_node") + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_edge") + + # Finer-grained: attention and MLP submodules + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_contraction" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_mlp" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_mlp" + ) + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}.mlp") + + # TODO: this is subtle, depending on whether the featurizer is feedforward or + # residual. To fix later. + # possible_capture_paths.append(f"combination_norms") + # possible_capture_paths.append(f"combination_mlps") + + for path in possible_capture_paths: + if "mtt::features::" + path not in outputs: + continue + + if "_node" in path: + suffix = "_node" + path = path.replace(suffix, "") + elif "_edge" in path: + suffix = "_edge" + path = path.replace(suffix, "") + else: + suffix = "" + + module = _resolve_module(path) + + def make_hook(p: str, suffix: str) -> Any: + def _hook(module: torch.nn.Module, inp: Any, outp: Any) -> None: + if isinstance(outp, tuple): + assert "_node" in suffix or "_edge" in suffix, ( + "When capturing from a module that returns multiple " + "outputs, the requested output must carry the suffix " + "'_node' or '_edge'." + ) + if suffix == "_node": + tensor = outp[0] + else: + assert suffix == "_edge" + tensor = outp[1] + + return_dict[f"mtt::features::{p}{suffix}"] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + else: + return_dict["mtt::features::" + p] = ( + self._create_diagnostic_feature_tensormap( + outp, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + + return _hook + + handle = module.register_forward_hook(make_hook(path, suffix)) + diagnostic_handles.append(handle) + + return diagnostic_handles + + def _calculate_features( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Calculate node and edge features using the selected featurization strategy. + Returns lists of feature tensors from GNN layers. 
+ + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors + - List of edge feature tensors + In the case of feedforward featurization, each list contains a single tensor + from the final GNN layer. In the case of residual featurization, each list + contains tensors from all GNN layers. + """ + if self.featurizer_type == "feedforward": + return self._feedforward_featurization_impl(inputs, use_manual_attention) + else: + return self._residual_featurization_impl(inputs, use_manual_attention) + + def _feedforward_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Feedforward featurization: iterates features through all GNN layers, + returning only the final layer outputs. Uses combination MLPs to mix + forward and reversed edge messages at each layer. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from the final GNN layer + - List of edge feature tensors from the final GNN layer + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + + input_node_embeddings = self.node_embedders[0](inputs["element_indices_nodes"]) + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for combination_norm, combination_mlp, gnn_layer in zip( + self.combination_norms, self.combination_mlps, self.gnn_layers, strict=True + ): + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + input_node_embeddings = output_node_embeddings + new_input_edge_embeddings = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + # input_messages = 0.5 * (output_edge_embeddings + new_input_messages) + concatenated = torch.cat( + [output_edge_embeddings, new_input_edge_embeddings], dim=-1 + ) + input_edge_embeddings = ( + input_edge_embeddings + + output_edge_embeddings + + combination_mlp(combination_norm(concatenated)) + ) + + node_features_list.append(input_node_embeddings) + edge_features_list.append(input_edge_embeddings) + return node_features_list, edge_features_list + + def _residual_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Residual featurization: saves intermediate features from each GNN layer + for use in readout. 
Averages forward and reversed edge messages between layers. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from all GNN layers + - List of edge feature tensors from all GNN layers + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for node_embedder, gnn_layer in zip( + self.node_embedders, self.gnn_layers, strict=True + ): + input_node_embeddings = node_embedder(inputs["element_indices_nodes"]) + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + node_features_list.append(output_node_embeddings) + edge_features_list.append(output_edge_embeddings) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + # (Flatten, index, and reshape to the original shape) + new_input_messages = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + input_edge_embeddings = 0.5 * (input_edge_embeddings + new_input_messages) + return node_features_list, edge_features_list + + def _calculate_long_range_features( + self, + systems: List[System], + node_features_list: List[torch.Tensor], + edge_distances: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + """ + Calculate long-range electrostatic features using Ewald summation. + Forces use_ewald=True during training for stability. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_distances: Tensor of edge distances [n_atoms, max_num_neighbors]. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :return: Tensor of long-range features [n_atoms, d_pet]. + """ + if self.training: + # Currently, the long-range implementation show instabilities + # during training if P3MCalculator is used instead of the + # EwaldCalculator. We will use the EwaldCalculator for training. 
+ self.long_range_featurizer.use_ewald = True + flattened_lengths = edge_distances[padding_mask] + short_range_features = ( + torch.stack(node_features_list).sum(dim=0) + * (1 / len(node_features_list)) ** 0.5 + ) + long_range_features = self.long_range_featurizer( + systems, short_range_features, flattened_lengths + ) + return long_range_features + + def _get_output_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Concatenate node and edge features from all layers into intermediate + feature representations. Edge features are summed with cutoff weighting. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping "features" to a TensorMap of intermediate + representations, either per-atom or summed over atoms. + """ + features_dict: Dict[str, TensorMap] = {} + node_features = torch.cat(node_features_list, dim=1) + edge_features = torch.cat(edge_features_list, dim=2) + edge_features = (edge_features * cutoff_factors[:, :, None]).sum(dim=1) + features = torch.cat([node_features, edge_features], dim=1) + + feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=features, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + features.shape[-1], device=features.device + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + feature_tmap = mts.slice( + feature_tmap, + axis="samples", + selection=selected_atoms, + ) + if requested_outputs["features"].per_atom: + features_dict["features"] = feature_tmap + else: + features_dict["features"] = sum_over_atoms(feature_tmap) + return features_dict + + def _calculate_last_layer_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + ) -> Tuple[Dict[str, List[torch.Tensor]], Dict[str, List[torch.Tensor]]]: + """ + Apply output-specific heads to node and edge features from each GNN layer. + Returns dictionaries mapping output names to lists of head-transformed features. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of node last layer features + - Dictionary mapping output names to lists of edge last layer features + """ + node_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + + # Calculating node last layer features + for output_name, node_heads in self.node_heads.items(): + if output_name not in node_last_layer_features_dict: + node_last_layer_features_dict[output_name] = [] + for i, node_head in enumerate(node_heads): + node_last_layer_features_dict[output_name].append( + node_head(node_features_list[i]) + ) + + # Calculating edge last layer features + for output_name, edge_heads in self.edge_heads.items(): + if output_name not in edge_last_layer_features_dict: + edge_last_layer_features_dict[output_name] = [] + for i, edge_head in enumerate(edge_heads): + edge_last_layer_features_dict[output_name].append( + edge_head(edge_features_list[i]) + ) + + return node_last_layer_features_dict, edge_last_layer_features_dict + + def _get_output_last_layer_features( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge last layer features for requested last layer + features output. Edge features are summed with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping requested last layer features output names + to TensorMaps of last layer features, either per-atom or summed over atoms. 
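+
+        Example (a minimal sketch of the naming convention; ``energy`` is a
+        hypothetical target name)::
+
+            outputs = {
+                "energy": ModelOutput(per_atom=False),
+                "mtt::aux::energy_last_layer_features": ModelOutput(per_atom=True),
+            }
+            # The auxiliary entry makes this method return the concatenated
+            # node and edge last-layer features for the "energy" target.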
+ """ + last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + last_layer_features_outputs: Dict[str, TensorMap] = {} + for output_name in node_last_layer_features_dict.keys(): + if not should_compute_last_layer_features(output_name, requested_outputs): + continue + if output_name not in last_layer_features_dict: + last_layer_features_dict[output_name] = [] + for i in range(len(node_last_layer_features_dict[output_name])): + node_last_layer_features = node_last_layer_features_dict[output_name][i] + edge_last_layer_features = edge_last_layer_features_dict[output_name][i] + edge_last_layer_features = ( + edge_last_layer_features * cutoff_factors[:, :, None] + ).sum(dim=1) + last_layer_features_dict[output_name].append(node_last_layer_features) + last_layer_features_dict[output_name].append(edge_last_layer_features) + + for output_name in requested_outputs: + if not ( + output_name.startswith("mtt::aux::") + and output_name.endswith("_last_layer_features") + ): + continue + base_name = output_name.replace("mtt::aux::", "").replace( + "_last_layer_features", "" + ) + # the corresponding output could be base_name or mtt::base_name + if f"mtt::{base_name}" in last_layer_features_dict: + base_name = f"mtt::{base_name}" + last_layer_features_values = torch.cat( + last_layer_features_dict[base_name], dim=1 + ) + last_layer_feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=last_layer_features_values, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + last_layer_features_values.shape[-1], + device=last_layer_features_values.device, + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + last_layer_feature_tmap = mts.slice( + last_layer_feature_tmap, + axis="samples", + selection=selected_atoms, + ) + last_layer_features_options = requested_outputs[output_name] + if last_layer_features_options.per_atom: + last_layer_features_outputs[output_name] = last_layer_feature_tmap + else: + last_layer_features_outputs[output_name] = sum_over_atoms( + last_layer_feature_tmap + ) + return last_layer_features_outputs + + def _calculate_atomic_predictions( + self, + systems: List[System], + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + padding_mask: torch.Tensor, + cutoff_factors: torch.Tensor, + outputs: Dict[str, ModelOutput], + ) -> Tuple[ + Dict[str, List[List[torch.Tensor]]], Dict[str, List[List[torch.Tensor]]] + ]: + """ + Apply final linear layers to last layer features to produce + per-atom predictions. Handles multiple blocks per output and sums + edge contributions with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param outputs: Dictionary of requested outputs. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of lists of node atomic + prediction tensors (one list per GNN layer, one tensor per block) + - Dictionary mapping output names to lists of lists of edge atomic + prediction tensors (one list per GNN layer, one tensor per block) + """ + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + + # Computing node atomic predictions. Since we have last layer features + # for each GNN layer, and each last layer can have multiple blocks, + # we apply each last layer block to each of the last layer features. + + batch_species_indices = self.species_to_species_index[ + torch.cat([system.types for system in systems], dim=0) + ] + + for output_name, node_last_layers in self.node_last_layers.items(): + if output_name in outputs: + node_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, node_last_layer in enumerate(node_last_layers): + node_last_layer_features = node_last_layer_features_dict[ + output_name + ][i] + node_atomic_predictions_by_block: List[torch.Tensor] = [] + for node_last_layer_by_block in node_last_layer.values(): + node_atomic_predictions_by_block.append( + node_last_layer_by_block( + batch_species_indices, + node_last_layer_features, + ) + ) + node_atomic_predictions_dict[output_name].append( + node_atomic_predictions_by_block + ) + + # Computing edge atomic predictions. Following the same logic as above, + # we (1) iterate over the last layer features and last layer blocks, and (2) + # sum the edge features with cutoff factors to get their per-node contribution. + + for output_name, edge_last_layers in self.edge_last_layers.items(): + if output_name in outputs: + edge_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, edge_last_layer in enumerate(edge_last_layers): + edge_last_layer_features = edge_last_layer_features_dict[ + output_name + ][i] + edge_atomic_predictions_by_block: List[torch.Tensor] = [] + for edge_last_layer_by_block in edge_last_layer.values(): + edge_atomic_predictions = edge_last_layer_by_block( + batch_species_indices, edge_last_layer_features + ) + expanded_padding_mask = padding_mask[..., None].repeat( + 1, 1, edge_atomic_predictions.shape[2] + ) + edge_atomic_predictions = torch.where( + ~expanded_padding_mask, 0.0, edge_atomic_predictions + ) + edge_atomic_predictions_by_block.append( + (edge_atomic_predictions * cutoff_factors[:, :, None]).sum( + dim=1 + ) + ) + edge_atomic_predictions_dict[output_name].append( + edge_atomic_predictions_by_block + ) + + return node_atomic_predictions_dict, edge_atomic_predictions_dict + + def _get_output_atomic_predictions( + self, + systems: List[System], + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_vectors: torch.Tensor, + system_indices: torch.Tensor, + sample_labels: Labels, + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge atomic predictions into final TensorMaps. + Handles rank-2 Cartesian tensors by symmetrizing them. + Returns per-atom or per-structure predictions based on output configuration. + + :param systems: List of `metatomic.torch.System` objects to process. 
+ :param node_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of node atomic prediction tensors (one list per GNN layer, + one tensor per block). + :param edge_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of edge atomic prediction tensors (one list per GNN layer, + one tensor per block). + :param edge_vectors: Tensor of edge vectors [n_atoms, max_num_neighbors, 3]. + :param system_indices: Tensor mapping each atom to its system index + [n_atoms]. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param outputs: Dictionary of requested outputs. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :return: Dictionary mapping requested output names to TensorMaps of + predictions, either per-atom or summed over atoms. + """ + atomic_predictions_tmap_dict: Dict[str, TensorMap] = {} + for output_name in self.target_names: + if output_name in outputs: + atomic_predictions_by_block = { + key: torch.zeros( + 1, dtype=edge_vectors.dtype, device=edge_vectors.device + ) + for key in self.output_shapes[output_name].keys() + } + + node_atomic_predictions_by_block = node_atomic_predictions_dict[ + output_name + ] + edge_atomic_predictions_by_block = edge_atomic_predictions_dict[ + output_name + ] + for i in range(len(node_atomic_predictions_by_block)): + node_atomic_prediction_block = node_atomic_predictions_by_block[i] + edge_atomic_prediction_block = edge_atomic_predictions_by_block[i] + for j, key in enumerate(atomic_predictions_by_block): + node_atomic_predictions = node_atomic_prediction_block[j] + edge_atomic_predictions = edge_atomic_prediction_block[j] + atomic_predictions_by_block[key] = atomic_predictions_by_block[ + key + ] + (node_atomic_predictions + edge_atomic_predictions) + + if output_name == "non_conservative_stress": # TODO: variants + block_key = list(atomic_predictions_by_block.keys())[0] + output_shapes_values = list( + self.output_shapes[output_name].values() + ) + num_properties = output_shapes_values[0][-1] + symmetrized = process_non_conservative_stress( + atomic_predictions_by_block[block_key], + systems, + system_indices, + num_properties, + ) + atomic_predictions_by_block[block_key] = symmetrized + + blocks = [ + TensorBlock( + values=atomic_predictions_by_block[key].reshape([-1] + shape), + samples=sample_labels, + components=components, + properties=properties, + ) + for key, shape, components, properties in zip( + self.output_shapes[output_name].keys(), + self.output_shapes[output_name].values(), + self.component_labels[output_name], + self.property_labels[output_name], + strict=True, + ) + ] + atomic_predictions_tmap_dict[output_name] = TensorMap( + keys=self.key_labels[output_name], + blocks=blocks, + ) + # If selected atoms request is provided, we slice the atomic predictions + # tensor maps to get the predictions for the selected atoms only. + + if selected_atoms is not None: + for output_name, tmap in atomic_predictions_tmap_dict.items(): + atomic_predictions_tmap_dict[output_name] = mts.slice( + tmap, axis="samples", selection=selected_atoms + ) + + # If per-atom predictions are requested, we return the atomic predictions + # tensor maps. Otherwise, we sum the atomic predictions over the atoms + # to get the final per-structure predictions for each requested output. 
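+        # Illustrative sketch (toy shapes, not from a real run): for a batch of
+        # two systems with 3 and 2 atoms, a per-atom block holds 5 sample rows
+        # labelled ("system", "atom"); summing over atoms reduces it to 2 rows,
+        # one per structure, by adding the rows that share the same "system"
+        # index.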
+ + for output_name, atomic_property in atomic_predictions_tmap_dict.items(): + if outputs[output_name].per_atom: + atomic_predictions_tmap_dict[output_name] = atomic_property + else: + atomic_predictions_tmap_dict[output_name] = sum_over_atoms( + atomic_property + ) + + return atomic_predictions_tmap_dict + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + context: Literal["restart", "finetune", "export"], + ) -> "PET": + if context == "restart": + logging.info(f"Using latest model from epoch {checkpoint['epoch']}") + model_state_dict = checkpoint["model_state_dict"] + elif context in {"finetune", "export"}: + logging.info(f"Using best model from epoch {checkpoint['best_epoch']}") + model_state_dict = checkpoint["best_model_state_dict"] + else: + raise ValueError("Unknown context tag for checkpoint loading!") + + # Create the model + model_data = checkpoint["model_data"] + model = cls( + hypers=model_data["model_hypers"], + dataset_info=model_data["dataset_info"], + ) + + finetune_config = model_state_dict.pop("finetune_config", {}) + if finetune_config: + # Apply the finetuning strategy + model = apply_finetuning_strategy(model, finetune_config) + state_dict_iter = iter(model_state_dict.values()) + next(state_dict_iter) # skip the species_to_species_index + dtype = next(state_dict_iter).dtype + model.to(dtype).load_state_dict(model_state_dict) + model.additive_models[0].sync_tensor_maps() + model.scaler.sync_tensor_maps() + + # Loading the metadata from the checkpoint + model.metadata = merge_metadata(model.metadata, checkpoint.get("metadata")) + + return model + + def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: + dtype = next(self.parameters()).dtype + if dtype not in self.__supported_dtypes__: + raise ValueError(f"unsupported dtype {dtype} for PET") + + # Make sure the model is all in the same dtype + # For example, after training, the additive models could still be in + # float64 + self.to(dtype) + + # Additionally, the composition model contains some `TensorMap`s that cannot + # be registered correctly with Pytorch. This function moves them: + self.additive_models[0].weights_to(torch.device("cpu"), torch.float64) + + interaction_ranges = [self.num_gnn_layers * self.cutoff] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + + capabilities = ModelCapabilities( + outputs=self.outputs, + atomic_types=self.atomic_types, + interaction_range=interaction_range, + length_unit=self.dataset_info.length_unit, + supported_devices=self.__supported_devices__, + dtype=dtype_to_str(dtype), + ) + + metadata = merge_metadata(self.metadata, metadata) + + return AtomisticModel(self.eval(), metadata, capabilities) + + def _add_output(self, target_name: str, target_info: TargetInfo) -> None: + """ + Register a new output target by creating corresponding heads and last layers. + Sets up node/edge heads and linear layers for all readout layers. + + :param target_name: Name of the target to add. + :param target_info: TargetInfo object containing details about the target. + """ + # one output shape for each tensor block, grouped by target (i.e. 
tensormap) + self.output_shapes[target_name] = {} + for key, block in target_info.layout.items(): + dict_key = target_name + for n, k in zip(key.names, key.values, strict=True): + dict_key += f"_{n}_{int(k)}" + self.output_shapes[target_name][dict_key] = [ + len(comp.values) for comp in block.components + ] + [len(block.properties.values)] + + self.outputs[target_name] = ModelOutput( + quantity=target_info.quantity, + unit=target_info.unit, + per_atom=True, + description=target_info.description, + ) + + self.node_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_node, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_pet, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.node_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: ReadoutLayer( + self.d_head, + prod(shape), + len(self.atomic_types), + bias=True, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: ReadoutLayer( + self.d_head, + prod(shape), + len(self.atomic_types), + bias=True, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + # Register last-layer parameters, in the same order as they are returned as + # last-layer features in the model + self.last_layer_parameter_names[target_name] = [] + for layer_index in range(self.num_readout_layers): + for key in self.output_shapes[target_name].keys(): + self.last_layer_parameter_names[target_name].append( + f"node_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + self.last_layer_parameter_names[target_name].append( + f"edge_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + + ll_features_name = get_last_layer_features_name(target_name) + self.outputs[ll_features_name] = ModelOutput( + per_atom=True, description=f"last layer features for {target_name}" + ) + self.key_labels[target_name] = target_info.layout.keys + self.component_labels[target_name] = [ + block.components for block in target_info.layout.blocks() + ] + self.property_labels[target_name] = [ + block.properties for block in target_info.layout.blocks() + ] + + def _move_labels_to_device(self, device: torch.device) -> None: + self.single_label = self.single_label.to(device) + self.key_labels = { + output_name: label.to(device) + for output_name, label in self.key_labels.items() + } + self.component_labels = { + output_name: [ + [labels.to(device) for labels in components_block] + for components_block in components_tmap + ] + for output_name, components_tmap in self.component_labels.items() + } + self.property_labels = { + output_name: [labels.to(device) for labels in properties_tmap] + for output_name, properties_tmap in self.property_labels.items() + } + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["model_ckpt_version"] == v: + update = getattr(checkpoints, f"model_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["model_ckpt_version"] = v + 1 + + if 
checkpoint["model_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using model " + f"version {checkpoint['model_ckpt_version']}, while the current model " + f"version is {cls.__checkpoint_version__}." + ) + + return checkpoint + + def get_checkpoint(self) -> Dict: + model_state_dict = self.state_dict() + model_state_dict["finetune_config"] = self.finetune_config + checkpoint = { + "architecture_name": "pet_3", + "model_ckpt_version": self.__checkpoint_version__, + "metadata": self.metadata, + "model_data": { + "model_hypers": self.hypers, + "dataset_info": self.dataset_info, + }, + "epoch": None, + "best_epoch": None, + "model_state_dict": model_state_dict, + "best_model_state_dict": self.state_dict(), + } + return checkpoint + + +def process_non_conservative_stress( + tensor: torch.Tensor, + systems: List[System], + system_indices: torch.Tensor, + num_properties: int, +) -> torch.Tensor: + """ + Symmetrizes and normalizes by the volume rank-2 Cartesian tensors that are meant + to predict the non-conservative stress. + + :param tensor: Tensor of shape [n_atoms, 9 * num_properties]. + :param systems: List of `metatomic.torch.System` objects to process. + :param system_indices: Tensor mapping each atom to its system index [n_atoms]. + :param num_properties: Number of properties in the tensor (e.g., 6 for stress). + :return: Symmetrized tensor of shape [n_atoms, 3, 3, num_properties], divided by the + cell volume. + """ + # Reshape to 3x3 matrix per atom + tensor_as_three_by_three = tensor.reshape(-1, 3, 3, num_properties) + + # Normalize by cell volume + volumes = torch.stack([torch.abs(torch.det(system.cell)) for system in systems]) + # Zero volume can happen due to metatomic's convention of zero cell + # vectors for non-periodic directions. The actual volume is +inf + volumes[volumes == 0.0] = torch.inf + volumes_by_atom = volumes[system_indices].unsqueeze(1).unsqueeze(2).unsqueeze(3) + tensor_as_three_by_three = tensor_as_three_by_three / volumes_by_atom + + # Symmetrize + tensor_as_three_by_three = ( + tensor_as_three_by_three + tensor_as_three_by_three.transpose(1, 2) + ) / 2.0 + + return tensor_as_three_by_three + + +def get_last_layer_features_name(target_name: str) -> str: + """ + Get the auxiliary output name for last layer features of a target. + + :param target_name: Name of the target. + :return: Name of the corresponding last layer features output. + """ + base_name = target_name.replace("mtt::", "") + return f"mtt::aux::{base_name}_last_layer_features" + + +def should_compute_last_layer_features( + output_name: str, requested_outputs: Dict[str, ModelOutput] +) -> bool: + """ + Check if last layer features should be computed for an output. + + :param output_name: Name of the output to check. + :param requested_outputs: Dictionary of requested outputs. + :return: True if last layer features should be computed, False otherwise. + """ + if output_name in requested_outputs: + return True + ll_features_name = get_last_layer_features_name( + output_name.replace("mtt::aux::", "") + ) + return ll_features_name in requested_outputs diff --git a/src/metatrain/pet_3/modules/adaptive_cutoff.py b/src/metatrain/pet_3/modules/adaptive_cutoff.py new file mode 100644 index 0000000000..5f6206bdb1 --- /dev/null +++ b/src/metatrain/pet_3/modules/adaptive_cutoff.py @@ -0,0 +1,166 @@ +from typing import Optional + +import torch + +from .utilities import cutoff_func_bump as cutoff_func + + +# minimum value for the probe cutoff. 
this avoids getting too close
+# to the central atom. in practice it could also be set to a larger value
+DEFAULT_MIN_PROBE_CUTOFF = 0.5
+# recommended smooth cutoff width for effective neighbor number calculation
+# smaller values lead to a more "step-like" behavior, but can be
+# numerically unstable. in practice this will be called with the
+# same cutoff as the main cutoff function
+DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH = 1.0
+
+
+def get_adaptive_cutoffs(
+    centers: torch.Tensor,
+    edge_distances: torch.Tensor,
+    num_neighbors_adaptive: float,
+    num_nodes: int,
+    max_cutoff: float,
+    min_cutoff: float = DEFAULT_MIN_PROBE_CUTOFF,
+    cutoff_width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH,
+    probe_spacing: Optional[float] = None,
+    weight_width: Optional[float] = None,
+) -> torch.Tensor:
+    """
+    Computes the adaptive cutoff values for each center atom.
+
+    :param centers: Indices of the center atoms.
+    :param edge_distances: Distances between centers and their neighbors.
+    :param num_neighbors_adaptive: Target number of neighbors per atom.
+    :param num_nodes: Total number of center atoms.
+    :param max_cutoff: Maximum cutoff distance to consider.
+    :param min_cutoff: Minimum cutoff distance to consider.
+    :param cutoff_width: Width of the smooth cutoff taper region.
+    :param probe_spacing: Spacing between probe cutoffs. If None, it will be
+        automatically determined from the cutoff width.
+    :param weight_width: Width of the cutoff selection weight function. If None, it
+        will be automatically determined from the empirical neighbor counts.
+    :return: Adapted cutoff distances for each center atom.
+    """
+
+    # heuristic for the grid spacing of probe cutoffs, based on
+    # the smoothness of the cutoff function
+    if probe_spacing is None:
+        probe_spacing = cutoff_width / 4.0
+    probe_cutoffs = torch.arange(
+        min_cutoff,
+        max_cutoff,
+        probe_spacing,
+        device=edge_distances.device,
+        dtype=edge_distances.dtype,
+    )
+    with torch.profiler.record_function("PET::get_effective_num_neighbors"):
+        effective_num_neighbors = get_effective_num_neighbors(
+            edge_distances,
+            probe_cutoffs,
+            centers,
+            num_nodes,
+            width=cutoff_width,
+        )
+
+    with torch.profiler.record_function("PET::get_cutoff_weights"):
+        cutoffs_weights = get_gaussian_cutoff_weights(
+            effective_num_neighbors,
+            num_neighbors_adaptive,
+            width=weight_width,
+        )
+    with torch.profiler.record_function("PET::calculate_adapted_cutoffs"):
+        adapted_atomic_cutoffs = probe_cutoffs @ cutoffs_weights.T
+    return adapted_atomic_cutoffs
+
+
+def get_effective_num_neighbors(
+    edge_distances: torch.Tensor,
+    probe_cutoffs: torch.Tensor,
+    centers: torch.Tensor,
+    num_nodes: int,
+    width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH,
+) -> torch.Tensor:
+    """
+    Computes the effective number of neighbors for each probe cutoff.
+
+    :param edge_distances: Distances between centers and their neighbors.
+    :param probe_cutoffs: Probe cutoff distances.
+    :param centers: Indices of the center atoms.
+    :param num_nodes: Total number of center atoms.
+    :param width: Width of the cutoff function.
+    :return: Effective number of neighbors for each center atom and probe cutoff.
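+
+    Each edge contributes ``cutoff_func(edge_distance, probe_cutoff, width)``,
+    which is close to 1 well inside the probe cutoff and decays smoothly to 0
+    at the cutoff; the contributions are accumulated per center atom. As an
+    illustration (toy numbers, not from a real system), a center with neighbor
+    distances ``[0.5, 1.0, 3.5]`` probed at cutoffs ``[2.5, 5.0]`` yields
+    effective counts of roughly ``[2, 3]``.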
+ """ + + weights = cutoff_func( + edge_distances.unsqueeze(0), probe_cutoffs.unsqueeze(1), width + ) + + probe_num_neighbors = torch.zeros( + (len(probe_cutoffs), num_nodes), + dtype=edge_distances.dtype, + device=edge_distances.device, + ) + # accumulate the weights for all probe cutoffs and center atoms at once + probe_num_neighbors.index_add_(1, centers, weights) + probe_num_neighbors = probe_num_neighbors.T + + return probe_num_neighbors + + +def get_gaussian_cutoff_weights( + effective_num_neighbors: torch.Tensor, + num_neighbors_adaptive: float, + width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the weights for each probe cutoff based on + the effective number of neighbors using Gaussian weights + centered at the expected number of neighbors. + + :param effective_num_neighbors: Effective number of neighbors for each center atom + and probe cutoff. + :param num_neighbors_adaptive: Target maximum number of neighbors per atom. + :param width: Width of the Gaussian cutoff selection function. + :return: Weights for each probe cutoff. + """ + diff = effective_num_neighbors - num_neighbors_adaptive + + # adds a "baseline" corresponding to uniformly-distributed atoms + # this has multiple "good" effects: it pushes the cutoff "out" when + # there are few neighbors, and "in" when there are many, and it + # stabilizes the weights with respect to variations in the neighbor + # distribution when there are empty ranges leading to "flat" + # neighbor count distribution + x = torch.linspace( + 0, + 1, + effective_num_neighbors.shape[1], + device=effective_num_neighbors.device, + dtype=effective_num_neighbors.dtype, + ) + baseline = num_neighbors_adaptive * x**3 + + diff = diff + baseline.unsqueeze(0) + if width is None: + # adaptive width from neighbor-count slope along probe axis (last dim) + eps = 1e-12 + if diff.shape[-1] == 1: + # Can't compute gradient from single point; use scaled diff as proxy + width_t = diff.abs() * 0.5 + eps + else: + # Compute numerical gradient: centered differences for interior, + # one-sided differences at boundaries + (width_t,) = torch.gradient(diff, dim=-1) + width_t = width_t.abs().clamp_min(eps) + else: + width_t = torch.ones_like(diff) * width + + logw = -0.5 * (diff / width_t) ** 2 + weights = torch.exp(logw - logw.max()) + + # row-wise normalization of the weights + weights_sum = weights.sum(dim=1, keepdim=True) + weights = weights / weights_sum + + return weights diff --git a/src/metatrain/pet_3/modules/finetuning.py b/src/metatrain/pet_3/modules/finetuning.py new file mode 100644 index 0000000000..27f2086f19 --- /dev/null +++ b/src/metatrain/pet_3/modules/finetuning.py @@ -0,0 +1,264 @@ +# mypy: disable-error-code=misc +# We ignore misc errors in this file because TypedDict +# with default values is not allowed by mypy. 
+from typing import Any, Optional, Tuple + +import torch +import torch.nn as nn +from typing_extensions import Literal, NotRequired, TypedDict + + +class LoRaFinetuneConfig(TypedDict): + """Configuration for LoRA finetuning strategy.""" + + rank: int + """Rank of the LoRA matrices.""" + alpha: float + """Scaling factor for the LoRA matrices.""" + target_modules: NotRequired[list[str]] + + +class HeadsFinetuneConfig(TypedDict): + """Configuration for heads finetuning strategy.""" + + head_modules: list[str] + """List of module name prefixes for the prediction heads to finetune.""" + last_layer_modules: list[str] + """List of module name prefixes for the last layers to finetune.""" + + +class NoFinetuneHypers(TypedDict): + """Hypers that indicate that no finetuning is to be applied.""" + + read_from: None = None + """No finetuning is indicated by setting this argument to None. + + The rest of finetuning hyperparameters are then ignored. + """ + method: NotRequired[Any] + config: NotRequired[Any] + inherit_heads: NotRequired[Any] + + +class FullFinetuneHypers(TypedDict): + """Hyperparameters to use full finetuning of PET models. + + This means all model parameters are trainable. + """ + + method: Literal["full"] = "full" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: NotRequired[Any] + """No configuration needed for full finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class LoRaFinetuneHypers(TypedDict): + """Hyperparameters for LoRA finetuning of PET models. + + Injects LoRA layers and finetunes only them. + """ + + method: Literal["lora"] = "lora" + """Finetuning method to use""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: LoRaFinetuneConfig + """Configuration for LoRA finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class HeadsFinetuneHypers(TypedDict): + """Hyperparameters for heads finetuning of PET models. + + Freezes all model parameters except for the prediction heads + and last layers. + """ + + method: Literal["heads"] = "heads" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: HeadsFinetuneConfig + """Configuration for heads finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +FinetuneHypers = FullFinetuneHypers | LoRaFinetuneHypers | HeadsFinetuneHypers + + +def apply_finetuning_strategy(model: nn.Module, strategy: FinetuneHypers) -> nn.Module: + """ + Apply the specified finetuning strategy to the model. + This function modifies the model in place based on the provided strategy. + + :param model: The model to be finetuned. + :param strategy: A dictionary specifying the finetuning strategy. + The strategy method can be one of the following: + - lora: Inject LoRA layers into the model, or reapply training if already + present. 
+ - heads: Freeze all parameters except for the heads and last layers. + - full: All parameters are trainable. + Additionally, the strategy can include an "inherit_heads" key, + which is a dictionary mapping the new trainable targets to the existing + targets in the model. This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization. + :return: The modified model with the finetuning strategy applied. + """ + + for param in model.parameters(): + param.requires_grad = True + + if strategy["method"] == "full": + # Full finetuning, all parameters are trainable + pass + + elif strategy["method"] == "lora": + lora_config = strategy["config"] + lora_already_applied = any(isinstance(m, LoRALinear) for m in model.modules()) + if not lora_already_applied: + model_device = next(model.parameters()).device + model_dtype = next(model.parameters()).dtype + model = inject_lora_layers( + model, + target_modules=tuple( + lora_config.get("target_modules", ["input_linear", "output_linear"]) + ), + rank=lora_config.get("rank", 4), + alpha=lora_config.get("alpha", 8), + device=model_device, + dtype=model_dtype, + ) + + # Freeze all except LoRA + for name, param in model.named_parameters(): + if "lora_" not in name: + param.requires_grad = False + + elif strategy["method"] == "heads": + heads_config = strategy.get( + "config", + { + "head_modules": ["node_heads", "edge_heads"], + "last_layer_modules": ["node_last_layers", "edge_last_layers"], + }, + ) + + head_keywords = heads_config.get("head_modules", []) + last_layer_keywords = heads_config.get("last_layer_modules", []) + + for name, param in model.named_parameters(): + if any(name.startswith(kw) for kw in head_keywords + last_layer_keywords): + param.requires_grad = True + else: + param.requires_grad = False + + else: + raise ValueError( + f"Unknown finetuning strategy: {strategy['method']}. Available methods " + "are: 'full', 'lora', 'heads'." + ) + + model.finetune_config = strategy + + inherit_heads_config = strategy["inherit_heads"] + if inherit_heads_config: + for dest_target_name, source_target_name in inherit_heads_config.items(): + model_parameters = dict(model.named_parameters()) + if not any(f".{source_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the source target name '{source_target_name}' was not found in " + "the model. Please specify the correct source target name." + ) + if not any(f".{dest_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the destination target name '{dest_target_name}' was not found " + "in the model. Please specify the correct destination target name." + ) + for name, param in model_parameters.items(): + if f".{source_target_name}." in name: + corresponding_dest_name = name.replace( + source_target_name, dest_target_name + ) + if corresponding_dest_name in model_parameters: + model_parameters[corresponding_dest_name].data.copy_(param.data) + else: + raise ValueError( + f"Destination head '{dest_target_name}' not found in model." + ) + return model + + +def inject_lora_layers( + model: nn.Module, + target_modules: Tuple[str, ...] = ("input_linear", "output_linear"), + rank: int = 4, + alpha: float = 1.0, + device: Optional[str] = None, + dtype: Optional[torch.dtype] = None, +) -> nn.Module: + """ + Inject LoRA layers into the model. 
+ This function replaces the specified linear layers in the model with + LoRALinear layers. + + :param model: The model to modify. + :param target_modules: A tuple of strings specifying the names of the attributes of + the modules to be replaced with LoRA layers. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + :param device: The device to which the LoRA layers should be moved. If None, the + LoRA layers will be on the same device as the original model. + :param dtype: The data type to which the LoRA layers should be converted. If + None, the LoRA layers will have the same dtype as the original model. + :return: The modified model with LoRA layers injected. + """ + for _, module in model.named_modules(): + for attr in target_modules: + if hasattr(module, attr): + linear = getattr(module, attr) + if isinstance(linear, nn.Linear): + lora_linear = LoRALinear(linear, rank=rank, alpha=alpha) + lora_linear = lora_linear.to(dtype=dtype, device=device) + setattr(module, attr, lora_linear) + return model + + +class LoRALinear(nn.Module): + """ + LoRA Linear layer. + This is a wrapper around nn.Linear that adds LoRA functionality. + LoRA is a technique for low-rank adaptation of large language models. + It allows for efficient fine-tuning of large models by injecting low-rank + matrices into the model's weights. + + :param linear_layer: The original linear layer to be wrapped. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + """ + + def __init__(self, linear_layer: nn.Module, rank: int = 4, alpha: float = 1.0): + super().__init__() + self.linear = linear_layer + self.lora_A = nn.Linear(linear_layer.in_features, rank, bias=False) + self.lora_B = nn.Linear(rank, linear_layer.out_features, bias=False) + self.scaling = alpha / rank + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + self.scaling * self.lora_B(self.lora_A(x)) diff --git a/src/metatrain/pet_3/modules/nef.py b/src/metatrain/pet_3/modules/nef.py new file mode 100644 index 0000000000..20ed6dc8fc --- /dev/null +++ b/src/metatrain/pet_3/modules/nef.py @@ -0,0 +1,248 @@ +""" +Module with functions to manipulate NEF (Node Edge Feature) arrays. + +The NEF representation is what the internals of PET use. +In the NEF representation, the first dimension is the center node +(i.e. the "i" node in an "i -> j" edge), and the second dimension +is the edges for that node. Not all center nodes have the same number +of edges, so padding is used to ensure that all nodes have the same +number of edges. + +Most of the functions have the purpose of converting between +edge arrays with shape (n_edges, ...) and NEF arrays with shape +(n_nodes, n_edges_per_node, ...). +""" + +from typing import List, Optional, Tuple + +import torch + + +def get_nef_indices( + centers: torch.Tensor, n_nodes: int, n_edges_per_node: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes tensors of indices useful to convert between edge + and NEF layouts; the usage and function of `nef_indices` and + `nef_to_edges_neighbor` is clear in the ``edge_array_to_nef`` + and ``nef_array_to_edges`` functions below. + + :param centers: A 1D tensor of shape (n_edges,) containing the + indices of the center nodes for each edge, with the center nodes + being the "i" node in an "i -> j" edge. + :param n_nodes: The number of nodes in the graph. + :param n_edges_per_node: The maximum number of edges per node. 
+ + :return: A tuple with three tensors (nef_indices, nef_to_edges_neighbor, nef_mask). + In particular: + nef_array = edge_array[nef_indices] + edge_array = nef_array[centers, nef_to_edges_neighbor] + The third output, nef_mask, is a mask that can be used to + filter out the padding values in the NEF array, as different + nodes will have, in general, different number of edges. + """ + + bincount = torch.bincount(centers, minlength=n_nodes) + + arange = torch.arange(n_edges_per_node, device=centers.device) + arange_expanded = arange.view(1, -1).expand(n_nodes, -1) + nef_mask = arange_expanded < bincount.view(-1, 1) + + argsort = torch.argsort(centers, stable=True) + + nef_indices = torch.zeros( + (n_nodes, n_edges_per_node), dtype=torch.long, device=centers.device + ) + nef_indices[nef_mask] = argsort + + nef_to_edges_neighbor = torch.empty_like(centers, dtype=torch.long) + nef_to_edges_neighbor[argsort] = arange_expanded[nef_mask] + + return nef_indices, nef_to_edges_neighbor, nef_mask + + +def get_corresponding_edges(array: torch.Tensor) -> torch.Tensor: + """ + Computes the corresponding edge (i.e., the edge that goes in the + opposite direction) for each edge in the array; this is useful + in the message-passing operation. + + :param array: A 2D tensor of shape (n_edges, 5). For each i -> j + edge, the first column contains the index of the center node i, + the second column contains the index of the neighbor node j, + and the last three columns contain the cell shifts along x, y, and z + directions, respectively. + + :return: A 1D tensor of shape (n_edges,) containing, for each edge, + the index of the corresponding edge (i.e., the edge that goes + in the opposite direction). If the input array is empty, an + empty tensor is returned. + """ + + if array.numel() == 0: + return torch.empty((0,), dtype=array.dtype, device=array.device) + + array = array.to(torch.int64) # avoid overflow + + centers = array[:, 0] + neighbors = array[:, 1] + cell_shifts_x = array[:, 2] + cell_shifts_y = array[:, 3] + cell_shifts_z = array[:, 4] + + # will be useful later + negative_cell_shifts_x = -cell_shifts_x + negative_cell_shifts_y = -cell_shifts_y + negative_cell_shifts_z = -cell_shifts_z + + # create a unique identifier for each edge + # first, we shift the cell_shifts so that the minimum value is 0 + min_cell_shift_x = cell_shifts_x.min() + cell_shifts_x = cell_shifts_x - min_cell_shift_x + negative_cell_shifts_x = negative_cell_shifts_x - min_cell_shift_x + + min_cell_shift_y = cell_shifts_y.min() + cell_shifts_y = cell_shifts_y - min_cell_shift_y + negative_cell_shifts_y = negative_cell_shifts_y - min_cell_shift_y + + min_cell_shift_z = cell_shifts_z.min() + cell_shifts_z = cell_shifts_z - min_cell_shift_z + negative_cell_shifts_z = negative_cell_shifts_z - min_cell_shift_z + + max_centers_neigbors = centers.max() + 1 # same as neighbors.max() + 1 + max_shift_x = cell_shifts_x.max() + 1 + max_shift_y = cell_shifts_y.max() + 1 + max_shift_z = cell_shifts_z.max() + 1 + + size_1 = max_shift_z + size_2 = max_shift_y * size_1 + size_3 = max_shift_x * size_2 + size_4 = max_centers_neigbors * size_3 + + unique_id = ( + centers * size_4 + + neighbors * size_3 + + cell_shifts_x * size_2 + + cell_shifts_y * size_1 + + cell_shifts_z + ) + + # the inverse is the same, but centers and neighbors are swapped + # and we use the negative values of the cell_shifts + unique_id_inverse = ( + neighbors * size_4 + + centers * size_3 + + negative_cell_shifts_x * size_2 + + negative_cell_shifts_y * size_1 + + 
negative_cell_shifts_z + ) + + unique_id_argsort = unique_id.argsort() + unique_id_inverse_argsort = unique_id_inverse.argsort() + + corresponding_edges = torch.empty_like(centers) + corresponding_edges[unique_id_argsort] = unique_id_inverse_argsort + + return corresponding_edges.to(array.dtype) + + +def edge_array_to_nef( + edge_array: torch.Tensor, + nef_indices: torch.Tensor, + mask: Optional[torch.Tensor] = None, + fill_value: float = 0.0, +) -> torch.Tensor: + """Converts an edge array to a NEF array. + + :param edge_array: A tensor where the first dimension is the index of + the edge, i.e. with shape (n_edges, ...). + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param mask: An optional boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. If provided, + the output NEF array will have the values in the positions + where the mask is False set to ``fill_value``. + :param fill_value: The value to use to fill the positions in the + NEF array where the mask is False. Only used if ``mask`` is + provided. + + :return: A tensor with the same information as ``edge_array``, + but in NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + If ``mask`` is provided, the values in the positions where + the mask is False are set to ``fill_value``. + """ + if mask is None: + return edge_array[nef_indices] + else: + return torch.where( + mask.reshape(mask.shape + (1,) * (len(edge_array.shape) - 1)), + edge_array[nef_indices], + fill_value, + ) + + +def nef_array_to_edges( + nef_array: torch.Tensor, centers: torch.Tensor, nef_to_edges_neighbor: torch.Tensor +) -> torch.Tensor: + """Converts a NEF array to an edge array. + + :param nef_array: A tensor where the first two dimensions are the + indices of the NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + :param centers: The indices of the center nodes for each edge. + :param nef_to_edges_neighbor: The indices of the edges for each + neighbor in the NEF layout, as returned by the ``get_nef_indices`` function. + + :return: A tensor with the same information as ``nef_array``, + but in edge layout, i.e. with shape (n_edges, ...). + """ + return nef_array[centers, nef_to_edges_neighbor] + + +def compute_reversed_neighbor_list( + nef_indices: torch.Tensor, + corresponding_edges: torch.Tensor, + nef_mask: torch.Tensor, +) -> torch.Tensor: + """ + Creates a reversed neighborlist, where for each + center atom `i` and its neighbor `j` in the original + neighborlist, the position of atom `i` in the list + of neighbors of atom `j` is returned. + + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param corresponding_edges: The indices of the corresponding edges, + as returned by the ``get_corresponding_edges`` function. + :param nef_mask: A boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. + :return: A tensor of the same shape as ``nef_indices``, + where each entry contains the position of the center + atom in the neighborlist of the corresponding neighbor atom. 
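+
+    Illustrative example (toy neighbor lists, not from a real system): if atom
+    0 has neighbors ``[1, 2]`` and atom 1 has neighbors ``[2, 0]``, the entry
+    for the pair 0 -> 1 (position 0 in the list of atom 0) is 1, because atom
+    0 sits at position 1 in the neighbor list of atom 1.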
+ """ + num_atoms, max_num_neighbors = nef_indices.shape + + flat_edge_indices = nef_indices.reshape(-1) + flat_positions = torch.arange(max_num_neighbors, device=nef_indices.device).repeat( + num_atoms + ) + flat_mask = nef_mask.reshape(-1) + + if flat_edge_indices.numel() == 0: + max_edge_index = 0 + else: + max_edge_index = int(flat_edge_indices.max().item()) + 1 + size: List[int] = [max_edge_index] + + edge_index_to_position = torch.full( + size, + 0, + dtype=torch.long, + device=nef_indices.device, + ) + edge_index_to_position[flat_edge_indices[flat_mask]] = flat_positions[flat_mask] + + reverse_edge_idx = corresponding_edges[nef_indices] + reversed_neighbor_list = edge_index_to_position[reverse_edge_idx] + reversed_neighbor_list = reversed_neighbor_list.masked_fill(~nef_mask, 0) + + return reversed_neighbor_list diff --git a/src/metatrain/pet_3/modules/readout.py b/src/metatrain/pet_3/modules/readout.py new file mode 100644 index 0000000000..e5eaaae74b --- /dev/null +++ b/src/metatrain/pet_3/modules/readout.py @@ -0,0 +1,32 @@ +from typing import List +import torch + +from metatomic.torch import System + + +class ReadoutLayer(torch.nn.Module): + def __init__( + self, + feature_dim: int, + output_dim: int, + num_atomic_types: int, + bias: bool = True, + ): + super().__init__() + self.shift = torch.nn.Embedding( + num_atomic_types, feature_dim + ) + self.linear = torch.nn.Linear(feature_dim, output_dim, bias=bias) + + def forward( + self, + batch_species_indices: torch.Tensor, + features: torch.Tensor, + ) -> torch.Tensor: + + shift = self.shift(batch_species_indices) + if features.dim() == 3: # edge features + shift = shift.unsqueeze(1) + features = features + shift + + return self.linear(features) \ No newline at end of file diff --git a/src/metatrain/pet_3/modules/structures.py b/src/metatrain/pet_3/modules/structures.py new file mode 100644 index 0000000000..f7e67f9eaf --- /dev/null +++ b/src/metatrain/pet_3/modules/structures.py @@ -0,0 +1,377 @@ +from typing import List, Optional, Tuple + +import torch +from metatensor.torch import Labels +from metatomic.torch import NeighborListOptions, System + +from .adaptive_cutoff import get_adaptive_cutoffs +from .nef import ( + compute_reversed_neighbor_list, + edge_array_to_nef, + get_corresponding_edges, + get_nef_indices, +) +from .utilities import cutoff_func_bump, cutoff_func_cosine + + +def concatenate_structures( + systems: List[System], + neighbor_list_options: NeighborListOptions, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, +]: + """ + Concatenate a list of systems into a single batch. + + :param systems: List of systems to concatenate. + :param neighbor_list_options: Options for the neighbor list. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the concatenated positions, centers, neighbors, + species, cells, cell shifts, system indices, and sample labels. 
+ """ + + positions: List[torch.Tensor] = [] + centers: List[torch.Tensor] = [] + neighbors: List[torch.Tensor] = [] + species: List[torch.Tensor] = [] + cell_shifts: List[torch.Tensor] = [] + cells: List[torch.Tensor] = [] + system_indices: List[torch.Tensor] = [] + atom_indices: List[torch.Tensor] = [] + node_counter = 0 + + for i, system in enumerate(systems): + assert len(system.known_neighbor_lists()) >= 1, "no neighbor list found" + neighbor_list = system.get_neighbor_list(neighbor_list_options) + nl_values = neighbor_list.samples.values + + centers_values = nl_values[:, 0] + neighbors_values = nl_values[:, 1] + cell_shifts_values = nl_values[:, 2:] + + if selected_atoms is not None: + system_selected_atoms = selected_atoms.values[:, 1][ + selected_atoms.values[:, 0] == i + ] + unique_centers = torch.unique(centers_values) + system_selected_atoms = torch.unique( + torch.cat([system_selected_atoms, unique_centers]) + ) + # calculate the mapping from the ghost atoms to the real atoms + if torch.numel(unique_centers) == 0: + max_center_index = -1 + else: + max_center_index = int(unique_centers.max()) + ghost_to_real_index = torch.full( + [ + max_center_index + 1, + ], + -1, + device=centers_values.device, + dtype=centers_values.dtype, + ) + for j, unique_center_index in enumerate(unique_centers): + ghost_to_real_index[unique_center_index] = j + + centers_values = ghost_to_real_index[centers_values] + neighbors_values = ghost_to_real_index[neighbors_values] + else: + system_selected_atoms = torch.arange( + len(system), device=system.positions.device + ) + + positions.append(system.positions[system_selected_atoms]) + species.append(system.types[system_selected_atoms]) + + centers.append(centers_values + node_counter) + neighbors.append(neighbors_values + node_counter) + cell_shifts.append(cell_shifts_values) + + cells.append(system.cell) + + node_counter += len(system_selected_atoms) + system_indices.append( + torch.full((len(system_selected_atoms),), i, device=system.positions.device) + ) + atom_indices.append( + torch.arange(len(system_selected_atoms), device=system.positions.device) + ) + + positions = torch.cat(positions) + centers = torch.cat(centers) + neighbors = torch.cat(neighbors) + species = torch.cat(species) + cells = torch.stack(cells) + cell_shifts = torch.cat(cell_shifts) + system_indices = torch.cat(system_indices) + atom_indices = torch.cat(atom_indices) + + sample_values = torch.stack( + [system_indices, atom_indices], + dim=1, + ) + sample_labels = Labels( + names=["system", "atom"], + values=sample_values, + assume_unique=True, + ) + + return ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) + + +def systems_to_batch( + systems: List[System], + options: NeighborListOptions, + all_species_list: List[int], + species_to_species_index: torch.Tensor, + cutoff_function: str, + cutoff_width: float, + num_neighbors_adaptive: Optional[float] = None, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, + torch.Tensor, + torch.Tensor, +]: + """ + Converts a list of systems to a batch required for the PET model. + + :param systems: List of systems to convert to a batch. + :param options: Options for the neighbor list. + :param all_species_list: List of all atomic species in the dataset. + :param species_to_species_index: Mapping from atomic species to species indices. 
+ :param cutoff_function: Type of the smoothing function at the cutoff. + :param cutoff_width: Width of the cutoff function for a cutoff mask. + :param num_neighbors_adaptive: Optional maximum number of neighbors per atom. + If provided, the adaptive cutoff scheme will be used for each atom to + approximately select this number of neighbors. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the batch tensors. + The batch consists of the following tensors: + - `element_indices_nodes`: The atomic species of the central atoms + - `element_indices_neighbors`: The atomic species of the neighboring atoms + - `edge_vectors`: The cartesian edge vectors between the central atoms and their + neighbors + - `edge_distances`: The distances between the central atoms and their neighbors + - `padding_mask`: A padding mask indicating which neighbors are real, and which + are padded + - `reverse_neighbor_index`: The reversed neighbor list for each central atom + - `cutoff_factors`: The cutoff function values for each edge + - `system_indices`: The system index for each atom in the batch + - `sample_labels`: Labels indicating the system and atom indices for each atom + + """ + ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) = concatenate_structures(systems, options, selected_atoms) + + # somehow the backward of this operation is very slow at evaluation, + # where there is only one cell, therefore we simplify the calculation + # for that case + if len(cells) == 1: + cell_contributions = cell_shifts.to(cells.dtype) @ cells[0] + else: + cell_contributions = torch.einsum( + "ab, abc -> ac", + cell_shifts.to(cells.dtype), + cells[system_indices[centers]], + ) + edge_vectors = positions[neighbors] - positions[centers] + cell_contributions + edge_distances = torch.norm(edge_vectors, dim=-1) + 1e-15 + + if selected_atoms is not None: + if torch.numel(centers) == 0: + num_nodes = 0 + else: + num_nodes = int(centers.max()) + 1 + else: + num_nodes = len(positions) + + if num_neighbors_adaptive is not None: + with torch.profiler.record_function("PET::get_adaptive_cutoffs"): + # Adaptive cutoff scheme to approximately select `num_neighbors_adaptive` + # neighbors for each atom + atomic_cutoffs = get_adaptive_cutoffs( + centers, + edge_distances, + num_neighbors_adaptive, + num_nodes, + options.cutoff, + cutoff_width=cutoff_width, + ) + # Symmetrize the cutoffs between pairs of atoms (PET needs this symmetry + # due to its corresponding edge indexing ij -> ji) + pair_cutoffs = (atomic_cutoffs[centers] + atomic_cutoffs[neighbors]) / 2.0 + with torch.profiler.record_function("PET::adaptive_cutoff_masking"): + # Apply cutoff mask + cutoff_mask = edge_distances <= pair_cutoffs + + pair_cutoffs = pair_cutoffs[cutoff_mask] + centers = centers[cutoff_mask] + neighbors = neighbors[cutoff_mask] + edge_vectors = edge_vectors[cutoff_mask] + cell_shifts = cell_shifts[cutoff_mask] + edge_distances = edge_distances[cutoff_mask] + else: + pair_cutoffs = options.cutoff * torch.ones( + len(centers), device=positions.device, dtype=positions.dtype + ) + + num_neighbors = torch.bincount(centers, minlength=num_nodes) + max_edges_per_node = int(torch.max(num_neighbors)) + + # uncomment these to print out stats on the adaptive cutoff behavior + # print("adaptive_cutoffs", *pair_cutoffs.tolist()) + # print("num_neighbors", *num_neighbors.tolist()) + + if cutoff_function.lower() == "bump": + # use bump switching 
function for adaptive cutoff
+        cutoff_factors = cutoff_func_bump(edge_distances, pair_cutoffs, cutoff_width)
+    elif cutoff_function.lower() == "cosine":
+        # backward-compatible cosine switching for fixed cutoff
+        cutoff_factors = cutoff_func_cosine(edge_distances, pair_cutoffs, cutoff_width)
+    else:
+        raise ValueError(
+            f"Unknown cutoff function type: {cutoff_function}. "
+            f"Supported types are 'Cosine' and 'Bump'."
+        )
+
+    # Convert to NEF (Node-Edge-Feature) format:
+    nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices(
+        centers, num_nodes, max_edges_per_node
+    )
+
+    # Element indices
+    element_indices_nodes = species_to_species_index[species]
+    element_indices_neighbors = element_indices_nodes[neighbors]
+
+    # Send everything to NEF:
+    edge_vectors = edge_array_to_nef(edge_vectors, nef_indices)
+    edge_distances = torch.sqrt(torch.sum(edge_vectors**2, dim=2) + 1e-15)
+    element_indices_neighbors = edge_array_to_nef(
+        element_indices_neighbors, nef_indices
+    )
+    cutoff_factors = edge_array_to_nef(cutoff_factors, nef_indices, nef_mask, 0.0)
+
+    corresponding_edges = get_corresponding_edges(
+        torch.concatenate(
+            [centers.unsqueeze(-1), neighbors.unsqueeze(-1), cell_shifts],
+            dim=-1,
+        )
+    )
+
+    # These are the two arrays we need for message passing with edge reversals,
+    # if indexing happens in a two-dimensional way:
+    # edges_ji = edges_ij[reversed_neighbor_list, neighbors_index]
+    reversed_neighbor_list = compute_reversed_neighbor_list(
+        nef_indices, corresponding_edges, nef_mask
+    )
+    neighbors_index = edge_array_to_nef(neighbors, nef_indices).to(torch.int64)
+
+    # Here, we compute the array that allows indexing into a flattened
+    # version of the edge array (where the first two dimensions are merged):
+    reverse_neighbor_index = (
+        neighbors_index * neighbors_index.shape[1] + reversed_neighbor_list
+    )
+    # At this point, we have `reverse_neighbor_index[~nef_mask] = 0`, which however
+    # creates too many of the same index which slows down backward enormously.
+    # (See https://github.com/pytorch/pytorch/issues/41162)
+    # We therefore replace the padded indices with a sequence of unique indices.
+    reverse_neighbor_index[~nef_mask] = torch.arange(
+        int(torch.sum(~nef_mask)), device=reverse_neighbor_index.device
+    )
+
+    return (
+        element_indices_nodes,
+        element_indices_neighbors,
+        edge_vectors,
+        edge_distances,
+        nef_mask,
+        reverse_neighbor_index,
+        cutoff_factors,
+        system_indices,
+        sample_labels,
+        centers,
+        nef_to_edges_neighbor,
+    )
+
+
+def get_pair_sample_labels(
+    systems: List[System],
+    sample_labels: Labels,
+    nl_options: NeighborListOptions,
+    device: torch.device,
+) -> Labels:
+    """
+    Builds the pair samples labels for the input ``systems``, based on the pre-computed
+    neighbor list. These are 'off-site', i.e. not including self-interactions.
+
+    :param systems: List of systems to build the pair sample labels for.
+    :param sample_labels: The sample labels for per-atom quantities.
+    :param nl_options: The neighbor list options to use for building the offsite labels.
+    :param device: The device to put the labels on.
+    :return: The pair sample labels for the off-site pairs.
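+
+    Each row corresponds to one i -> j pair of the precomputed neighbor list,
+    with the system index prepended; e.g. (illustrative values) a row
+    ``(0, 3, 7, 0, 0, 1)`` labels the pair between atoms 3 and 7 of system 0
+    across a cell shift of (0, 0, 1).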
+ """ + sample_names = [ + "system", + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ] + + pair_sample_values = [] + for system_idx, system in enumerate(systems): + neighbor_list = system.get_neighbor_list(nl_options) + nl_values = neighbor_list.samples.values + + pair_sample_values.append( + torch.hstack( + [ + torch.full( + (nl_values.shape[0], 1), + system_idx, + dtype=torch.int32, + device=device, + ), + nl_values, + ], + ) + ) + pair_sample_values = torch.vstack(pair_sample_values) + pair_sample_labels = Labels(sample_names, pair_sample_values).to(device=device) + + return pair_sample_labels diff --git a/src/metatrain/pet_3/modules/transformer.py b/src/metatrain/pet_3/modules/transformer.py new file mode 100644 index 0000000000..5907314fec --- /dev/null +++ b/src/metatrain/pet_3/modules/transformer.py @@ -0,0 +1,548 @@ +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from .utilities import DummyModule + + +AVAILABLE_NORMALIZATIONS = ["LayerNorm", "RMSNorm"] +AVAILABLE_TRANSFORMER_TYPES = ["PostLN", "PreLN"] +AVAILABLE_ACTIVATIONS = ["SiLU", "SwiGLU"] + + +class FeedForward(nn.Module): + def __init__(self, d_model: int, dim_feedforward: int, activation: str) -> None: + super().__init__() + + # Check if activation is "swiglu" string + if activation.lower() == "swiglu": + # SwiGLU mode: single projection produces both "value" and "gate" + self.w_in = nn.Linear(d_model, 2 * dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = torch.nn.Identity() + self.is_swiglu = True + else: + # Standard mode: regular activation function + self.w_in = nn.Linear(d_model, dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = getattr(F, activation.lower()) + self.is_swiglu = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_swiglu: + # SwiGLU activation: split into value and gate + v, g = self.w_in(x).chunk(2, dim=-1) + x = v * torch.sigmoid(g) + x = self.w_out(x) + else: + # Standard activation + x = self.w_in(x) + x = self.activation(x) + x = self.w_out(x) + return x + + +class AttentionBlock(nn.Module): + """ + Multi-head attention block. + + :param total_dim: The total dimension of the input and output tensors. + :param num_heads: The number of attention heads. + :param temperature: An additional scaling factor for attention scores. + This is combined with the standard scaling by the square root of + the head dimension. + :param epsilon: A small value to avoid division by zero. + """ + + def __init__( + self, + total_dim: int, + num_heads: int, + temperature: float, + epsilon: float = 1e-15, + ) -> None: + super(AttentionBlock, self).__init__() + + self.input_linear = nn.Linear(total_dim, 3 * total_dim) + self.output_linear = nn.Linear(total_dim, total_dim) + + self.num_heads = num_heads + self.epsilon = epsilon + self.temperature = temperature + if total_dim % num_heads != 0: + raise ValueError("total dimension is not divisible by the number of heads") + self.head_dim = total_dim // num_heads + + def forward( + self, x: torch.Tensor, cutoff_factors: torch.Tensor, use_manual_attention: bool + ) -> torch.Tensor: + """ + Forward pass for the attention block. 
+ + :param x: The input tensor, of shape (batch_size, seq_length, total_dim) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: The output tensor, of shape (batch_size, seq_length, total_dim) + """ + initial_shape = x.shape + x = self.input_linear(x) + x = x.reshape( + initial_shape[0], initial_shape[1], 3, self.num_heads, self.head_dim + ) + x = x.permute(2, 0, 3, 1, 4) + + queries, keys, values = x[0], x[1], x[2] + attn_weights = torch.clamp(cutoff_factors[:, None, :, :], self.epsilon) + attn_weights = torch.log(attn_weights) + if use_manual_attention: + x = manual_attention(queries, keys, values, attn_weights, self.temperature) + else: + x = torch.nn.functional.scaled_dot_product_attention( + queries, + keys, + values, + attn_mask=attn_weights, + scale=1.0 / (self.head_dim**0.5 * self.temperature), + ) + x = x.transpose(1, 2).reshape(initial_shape) + x = self.output_linear(x) + return x + + +class TransformerLayer(torch.nn.Module): + """ + Single layer of a Transformer. + + :param d_model: The dimension of the model. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param temperature: An additional scaling factor for attention scores. 
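+
+    With ``transformer_type="PreLN"`` normalization is applied before the
+    attention and feedforward blocks; with ``"PostLN"`` it is applied after
+    each residual sum. A construction example (illustrative hyperparameters)::
+
+        layer = TransformerLayer(
+            d_model=128,
+            n_heads=8,
+            dim_node_features=128,
+            dim_feedforward=512,
+            norm="LayerNorm",
+            activation="SiLU",
+            transformer_type="PostLN",
+        )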
+ """ + + def __init__( + self, + d_model: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + temperature: float = 1.0, + ) -> None: + super(TransformerLayer, self).__init__() + self.attention = AttentionBlock(d_model, n_heads, temperature) + self.transformer_type = transformer_type + self.d_model = d_model + norm_class = getattr(nn, norm) + self.norm_attention = norm_class(d_model) + self.norm_mlp = norm_class(d_model) + self.mlp = FeedForward(d_model, dim_feedforward, activation) + self.expanded_node_features = False + if dim_node_features != d_model: + self.expanded_node_features = True + self.center_contraction = nn.Linear(dim_node_features, d_model) + self.center_expansion = nn.Linear(d_model, dim_node_features) + self.norm_center_features = norm_class(dim_node_features) + self.center_mlp = FeedForward( + dim_node_features, 2 * dim_node_features, activation + ) + else: + self.center_contraction = torch.nn.Identity() + self.center_expansion = torch.nn.Identity() + self.norm_center_features = torch.nn.Identity() + self.center_mlp = torch.nn.Identity() + + def _forward_pre_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + new_tokens = self.attention( + self.norm_attention(tokens), cutoff_factors, use_manual_attention + ) + output_node_embeddings, output_edge_embeddings = torch.split( + new_tokens, [1, new_tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + + output_edge_embeddings = edge_embeddings + output_edge_embeddings + output_edge_embeddings = output_edge_embeddings + self.mlp( + self.norm_mlp(output_edge_embeddings) + ) + + return output_node_embeddings, output_edge_embeddings + + def _forward_post_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + tokens = self.norm_attention( + tokens + self.attention(tokens, cutoff_factors, use_manual_attention) + ) + tokens = self.norm_mlp(tokens + self.mlp(tokens)) + output_node_embeddings, output_edge_embeddings = torch.split( + tokens, [1, tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + return output_node_embeddings, output_edge_embeddings + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward 
pass for a single Transformer layer. + + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + if self.transformer_type == "PostLN": + node_embeddings, edge_embeddings = self._forward_post_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + if self.transformer_type == "PreLN": + node_embeddings, edge_embeddings = self._forward_pre_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class Transformer(torch.nn.Module): + """ + Transformer implementation. + + :param d_model: The dimension of the model. + :param num_layers: The number of transformer layers. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param attention_temperature: The temperature scaling factor for attention + scores. This is combined with the standard scaling by the square root of + the head dimension. + """ + + def __init__( + self, + d_model: int, + num_layers: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + attention_temperature: float = 1.0, + ) -> None: + super(Transformer, self).__init__() + if norm not in AVAILABLE_NORMALIZATIONS: + raise ValueError( + f"Unknown normalization flag: {norm}. " + f"Please choose from: {AVAILABLE_NORMALIZATIONS}" + ) + + if transformer_type not in AVAILABLE_TRANSFORMER_TYPES: + raise ValueError( + f"Unknown transformer flag: {transformer_type}. " + f"Please choose from: {AVAILABLE_TRANSFORMER_TYPES}" + ) + self.transformer_type = transformer_type + + if activation not in AVAILABLE_ACTIVATIONS: + raise ValueError( + f"Unknown activation flag: {activation}. " + f"Please choose from: {AVAILABLE_ACTIVATIONS}" + ) + + self.layers = nn.ModuleList( + [ + TransformerLayer( + d_model=d_model, + n_heads=n_heads, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + temperature=attention_temperature, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the Transformer. 
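# --- Standalone sketch (not part of the patch): PostLN vs. PreLN ordering ---
# TransformerLayer.forward above dispatches between the two residual orderings:
# "PostLN" normalizes *after* each residual addition, while "PreLN" normalizes the
# sublayer input and leaves the residual stream un-normalized. The toy sublayer
# below is only a stand-in for the attention block or the MLP.
import torch

d_model = 16
norm = torch.nn.LayerNorm(d_model)
sublayer = torch.nn.Linear(d_model, d_model)  # stand-in for attention or the MLP
x = torch.randn(3, 7, d_model)

# PostLN: x <- Norm(x + Sublayer(x))
post_ln = norm(x + sublayer(x))

# PreLN: x <- x + Sublayer(Norm(x))
pre_ln = x + sublayer(norm(x))

print(post_ln.shape, pre_ln.shape)  # both (3, 7, 16)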
+ + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + for layer in self.layers: + node_embeddings, edge_embeddings = layer( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class CartesianTransformer(torch.nn.Module): + """ + Cartesian Transformer implementation for handling 3D coordinates. + + :param cutoff: The cutoff distance for neighbor interactions. + :param cutoff_width: The width of the cutoff function. + :param d_model: The dimension of the model. + :param n_head: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param n_layers: The number of transformer layers. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param attention_temperature: The temperature scaling factor for attention scores. + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param n_atomic_species: The number of atomic species. + :param is_first: Whether this is the first transformer in the model. + """ + + def __init__( + self, + cutoff: float, + cutoff_width: float, + d_model: int, + n_head: int, + dim_node_features: int, + dim_feedforward: int, + n_layers: int, + norm: str, + activation: str, + attention_temperature: float, + transformer_type: str, + n_atomic_species: int, + is_first: bool, + ) -> None: + super(CartesianTransformer, self).__init__() + self.is_first = is_first + self.cutoff = cutoff + self.cutoff_width = cutoff_width + self.trans = Transformer( + d_model=d_model, + num_layers=n_layers, + n_heads=n_head, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + attention_temperature=attention_temperature, + ) + + self.edge_embedder = nn.Linear(4, d_model) + + if not is_first: + n_merge = 3 + else: + n_merge = 2 + + self.compress = nn.Sequential( + nn.Linear(n_merge * d_model, d_model), + torch.nn.SiLU(), + nn.Linear(d_model, d_model), + ) + + self.neighbor_embedder = DummyModule() # for torchscript + if not is_first: + self.neighbor_embedder = nn.Embedding(n_atomic_species, d_model) + + def forward( + self, + input_node_embeddings: torch.Tensor, + input_messages: torch.Tensor, + element_indices_neighbors: torch.Tensor, + edge_vectors: torch.Tensor, + padding_mask: torch.Tensor, + edge_distances: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the CartesianTransformer. 
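# --- Standalone shape sketch (not part of the patch): building the edge tokens ---
# In the CartesianTransformer forward pass below, each edge contributes its 3D
# vector plus its length (4 values), embedded to d_model, and is then concatenated
# with the neighbor-species embedding and the incoming messages (3 * d_model for
# non-first layers, 2 * d_model for the first layer) before compression back to
# d_model. The sizes below are illustrative assumptions.
import torch

n_nodes, max_neighbors, d_model, n_species = 10, 6, 32, 4
edge_vectors = torch.randn(n_nodes, max_neighbors, 3)
edge_distances = edge_vectors.norm(dim=-1)
messages = torch.randn(n_nodes, max_neighbors, d_model)
neighbor_species = torch.randint(0, n_species, (n_nodes, max_neighbors))

edge_embedder = torch.nn.Linear(4, d_model)
neighbor_embedder = torch.nn.Embedding(n_species, d_model)
compress = torch.nn.Sequential(
    torch.nn.Linear(3 * d_model, d_model),
    torch.nn.SiLU(),
    torch.nn.Linear(d_model, d_model),
)

edge_embeddings = edge_embedder(
    torch.cat([edge_vectors, edge_distances[:, :, None]], dim=2)
)
edge_tokens = torch.cat(
    [edge_embeddings, neighbor_embedder(neighbor_species), messages], dim=2
)
edge_tokens = compress(edge_tokens)
print(edge_tokens.shape)  # torch.Size([10, 6, 32])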
+ + :param input_node_embeddings: The input node embeddings, of shape + (n_nodes, d_model) + :param input_messages: The input messages to the transformer, of shape + (n_nodes, max_num_neighbors, d_model) + :param element_indices_neighbors: The atomic species of the neighboring atoms, + of shape (n_nodes, max_num_neighbors) + :param edge_vectors: The cartesian edge vectors between the central atoms and + their neighbors, of shape (n_nodes, max_num_neighbors, 3) + :param padding_mask: A padding mask indicating which neighbors are real, and + which are padded, of shape (n_nodes, max_num_neighbors) + :param edge_distances: The distances between the central atoms and their + neighbors, of shape (n_nodes, max_num_neighbors) + :param cutoff_factors: The cutoff factors for the edges, of shape + (n_nodes, max_num_neighbors) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (n_nodes, d_model) + - The output edge embeddings, of shape (n_nodes, max_num_neighbors, d_model) + """ + node_embeddings = input_node_embeddings + edge_embeddings = [edge_vectors, edge_distances[:, :, None]] + edge_embeddings = torch.cat(edge_embeddings, dim=2) + edge_embeddings = self.edge_embedder(edge_embeddings) + + if not self.is_first: + neighbor_elements_embeddings = self.neighbor_embedder( + element_indices_neighbors + ) + edge_tokens = torch.cat( + [edge_embeddings, neighbor_elements_embeddings, input_messages], dim=2 + ) + else: + neighbor_elements_embeddings = torch.empty( + 0, device=edge_vectors.device, dtype=edge_vectors.dtype + ) # for torch script + edge_tokens = torch.cat([edge_embeddings, input_messages], dim=2) + + edge_tokens = self.compress(edge_tokens) + # tokens = torch.cat([node_elements_embedding[:, None, :], tokens], dim=1) + + padding_mask_with_central_token = torch.ones( + padding_mask.shape[0], dtype=torch.bool, device=padding_mask.device + ) + total_padding_mask = torch.cat( + [padding_mask_with_central_token[:, None], padding_mask], dim=1 + ) + + cutoff_subfactors = torch.ones( + padding_mask.shape[0], + dtype=cutoff_factors.dtype, + device=padding_mask.device, + ) + cutoff_factors = torch.cat([cutoff_subfactors[:, None], cutoff_factors], dim=1) + cutoff_factors[~total_padding_mask] = 0.0 + + cutoff_factors = cutoff_factors[:, None, :] + cutoff_factors = cutoff_factors.repeat(1, cutoff_factors.shape[2], 1) + + initial_num_tokens = edge_vectors.shape[1] + max_num_tokens = input_messages.shape[1] + + output_node_embeddings, output_edge_embeddings = self.trans( + node_embeddings[:, None, :], + edge_tokens[:, :max_num_tokens, :], + cutoff_factors=cutoff_factors[ + :, : (max_num_tokens + 1), : (max_num_tokens + 1) + ], + use_manual_attention=use_manual_attention, + ) + if max_num_tokens < initial_num_tokens: + padding = torch.zeros( + output_edge_embeddings.shape[0], + initial_num_tokens - max_num_tokens, + output_edge_embeddings.shape[2], + device=output_edge_embeddings.device, + ) + output_edge_embeddings = torch.cat([output_edge_embeddings, padding], dim=1) + output_node_embeddings = output_node_embeddings.squeeze(1) + return output_node_embeddings, output_edge_embeddings + + +def manual_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: torch.Tensor, + temperature: float, +) -> torch.Tensor: + """ + 
Implements the attention operation manually, using basic PyTorch operations. + We need it because the built-in PyTorch attention does not support double backward, + which is needed when training with conservative forces. + + :param q: The queries + :param k: The keys + :param v: The values + :param attn_mask: The attention mask + :param temperature: An additional scaling factor for attention scores. + :return: The result of the attention operation + """ + attention_weights = ( + torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5 * temperature) + ) + attn_mask + attention_weights = attention_weights.softmax(dim=-1) + attention_output = torch.matmul(attention_weights, v) + return attention_output diff --git a/src/metatrain/pet_3/modules/utilities.py b/src/metatrain/pet_3/modules/utilities.py new file mode 100644 index 0000000000..5795c5aed8 --- /dev/null +++ b/src/metatrain/pet_3/modules/utilities.py @@ -0,0 +1,63 @@ +import torch + + +def cutoff_func_bump( + values: torch.Tensor, cutoff: torch.Tensor, width: float, eps: float = 1e-6 +) -> torch.Tensor: + """ + Bump cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :param eps: Avoid computing at values too close to the edges. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + f[mask_active] = 0.5 * ( + 1 + torch.tanh(1 / torch.tan(torch.pi * scaled_values[mask_active])) + ) + f[mask_smaller] = 1.0 + + return f + + +def cutoff_func_cosine( + values: torch.Tensor, cutoff: torch.Tensor, width: float +) -> torch.Tensor: + """ + Cosine cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + + f[mask_active] = 0.5 + 0.5 * torch.cos(torch.pi * scaled_values[mask_active]) + f[mask_smaller] = 1.0 + return f + + +class DummyModule(torch.nn.Module): + """Dummy torch module to make torchscript happy. 
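# --- Standalone check (not part of the patch): behavior of the cosine cutoff ---
# A quick numerical illustration of cutoff_func_cosine above: distances below
# cutoff - width map to 1, distances at or beyond the cutoff map to 0, and the
# transition region decays smoothly. The cutoff and width values are assumptions
# chosen to make the plateau / transition / zero regions visible.
import torch

cutoff, width = torch.tensor(5.0), 0.5
values = torch.tensor([4.0, 4.75, 5.0, 5.5])

scaled = (values - (cutoff - width)) / width
f = torch.zeros_like(scaled)
active = (scaled > 0.0) & (scaled < 1.0)
f[active] = 0.5 + 0.5 * torch.cos(torch.pi * scaled[active])
f[scaled <= 0.0] = 1.0

print(f)  # tensor([1.0000, 0.5000, 0.0000, 0.0000])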
+ This model should never be run""" + + def __init__(self) -> None: + super(DummyModule, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise RuntimeError("This model should never be run") diff --git a/src/metatrain/pet_3/trainer.py b/src/metatrain/pet_3/trainer.py new file mode 100644 index 0000000000..49e82212bc --- /dev/null +++ b/src/metatrain/pet_3/trainer.py @@ -0,0 +1,639 @@ +import copy +import logging +import math +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Union, cast + +import torch +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader, DistributedSampler + +from metatrain.utils.abc import ModelInterface, TrainerInterface +from metatrain.utils.additive import get_remove_additive_transform +from metatrain.utils.augmentation import RotationalAugmenter +from metatrain.utils.data import ( + CollateFn, + CombinedDataLoader, + Dataset, + get_num_workers, + unpack_batch, + validate_num_workers, +) +from metatrain.utils.distributed.batch_utils import should_skip_batch +from metatrain.utils.distributed.distributed_data_parallel import ( + DistributedDataParallel, +) +from metatrain.utils.distributed.slurm import DistributedEnvironment +from metatrain.utils.evaluate_model import evaluate_model +from metatrain.utils.io import check_file_extension +from metatrain.utils.logging import ROOT_LOGGER, MetricLogger +from metatrain.utils.loss import LossAggregator, LossSpecification +from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists_transform, +) +from metatrain.utils.per_atom import average_by_num_atoms +from metatrain.utils.scaler import get_remove_scale_transform +from metatrain.utils.transfer import batch_to + +from . import checkpoints +from .documentation import TrainerHypers +from .model import PET +from .modules.finetuning import apply_finetuning_strategy + + +def get_scheduler( + optimizer: torch.optim.Optimizer, + train_hypers: TrainerHypers, + steps_per_epoch: int, +) -> LambdaLR: + """ + Get a CosineAnnealing learning-rate scheduler with warmup + + :param optimizer: The optimizer for which to create the scheduler. + :param train_hypers: The training hyperparameters. + :param steps_per_epoch: The number of steps per epoch. + :return: The learning rate scheduler. 
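# --- Standalone sketch (not part of the patch): warmup + cosine decay schedule ---
# The multiplier returned by the scheduler implemented below ramps linearly from 0
# to 1 over the warmup steps, then follows a cosine decay from 1 down to
# min_lr_ratio over the remaining steps. The step counts are illustrative.
import math

total_steps, warmup_steps, min_lr_ratio = 1000, 100, 0.0


def lr_multiplier(step: int) -> float:
    if step < warmup_steps:
        return step / max(1, warmup_steps)  # linear warmup
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr_ratio + (1.0 - min_lr_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))


print([round(lr_multiplier(s), 3) for s in (0, 50, 100, 550, 1000)])
# [0.0, 0.5, 1.0, 0.5, 0.0]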
+ """ + total_steps = train_hypers["num_epochs"] * steps_per_epoch + warmup_steps = int(train_hypers["warmup_fraction"] * total_steps) + min_lr_ratio = 0.0 # hardcoded for now, could be made configurable in the future + + def lr_lambda(current_step: int) -> float: + if current_step < warmup_steps: + # Linear warmup + return float(current_step) / float(max(1, warmup_steps)) + else: + # Cosine decay + progress = (current_step - warmup_steps) / float( + max(1, total_steps - warmup_steps) + ) + cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) + return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay + + scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) + return scheduler + + +class Trainer(TrainerInterface[TrainerHypers]): + __checkpoint_version__ = 12 + + def __init__(self, hypers: TrainerHypers) -> None: + super().__init__(hypers) + + self.optimizer_state_dict: Optional[Dict[str, Any]] = None + self.scheduler_state_dict: Optional[Dict[str, Any]] = None + self.epoch: Optional[int] = None + self.best_epoch: Optional[int] = None + self.best_metric: Optional[float] = None + self.best_model_state_dict: Optional[Dict[str, Any]] = None + self.best_optimizer_state_dict: Optional[Dict[str, Any]] = None + + def train( + self, + model: PET, + dtype: torch.dtype, + devices: List[torch.device], + train_datasets: List[Union[Dataset, torch.utils.data.Subset]], + val_datasets: List[Union[Dataset, torch.utils.data.Subset]], + checkpoint_dir: str, + ) -> None: + assert dtype in PET.__supported_dtypes__ + + is_distributed = self.hypers["distributed"] + is_finetune = self.hypers["finetune"]["read_from"] is not None + + if is_distributed: + if len(devices) > 1: + raise ValueError( + "Requested distributed training with the `multi-gpu` device. " + " If you want to run distributed training with PET, please " + "set `device` to cuda." 
+ ) + # the calculation of the device number works both when GPUs on different + # processes are not visible to each other and when they are + distr_env = DistributedEnvironment(self.hypers["distributed_port"]) + device_number = distr_env.local_rank % torch.cuda.device_count() + device = torch.device("cuda", device_number) + torch.distributed.init_process_group(backend="nccl", device_id=device) + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + else: + rank = 0 + device = devices[0] + # only one device, as we don't support non-distributed multi-gpu for now + + if is_distributed: + logging.info(f"Training on {world_size} devices with dtype {dtype}") + else: + logging.info(f"Training on device {device} with dtype {dtype}") + + # Apply fine-tuning strategy if provided + if is_finetune: + assert self.hypers["finetune"]["read_from"] is not None # for mypy + model = apply_finetuning_strategy(model, self.hypers["finetune"]) + method = self.hypers["finetune"]["method"] + num_params = sum(p.numel() for p in model.parameters()) + num_trainable_params = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + + logging.info(f"Applied finetuning strategy: {method}") + logging.info( + f"Number of trainable parameters: {num_trainable_params} " + f"[{num_trainable_params / num_params:.2%} %]" + ) + inherit_heads = self.hypers["finetune"]["inherit_heads"] + if inherit_heads: + logging.info( + "Inheriting initial weights for heads and last layers for targets: " + f"from {list(inherit_heads.values())} to " + f"{list(inherit_heads.keys())}" + ) + + # Move the model to the device and dtype: + model.to(device=device, dtype=dtype) + # The additive models of PET are always in float64 (to avoid numerical errors in + # the composition weights, which can be very large). 
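# --- Standalone numeric illustration (made-up numbers, not part of the patch) ---
# Why the composition model and the scaler are kept in float64 above: per-atom
# baselines can be thousands of eV, so the summed baseline of a few thousand atoms
# reaches ~1e7 eV, where the spacing between adjacent float32 values is about 1 eV
# and small, physically relevant energy differences are rounded away.
import torch

baseline_total = -2042.37 * 5000  # ~ -1.02e7 eV, e.g. 5000 atoms at -2042.37 eV each

f32 = torch.tensor(baseline_total, dtype=torch.float32)
f64 = torch.tensor(baseline_total, dtype=torch.float64)

print(torch.finfo(torch.float32).eps * abs(baseline_total))  # ~1.2: float32 spacing here
print(float(f32 + 0.05 - f32))  # 0.0: a 0.05 eV difference is lost in float32
print(float(f64 + 0.05 - f64))  # ~0.05: preserved in float64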
+ for additive_model in model.additive_models: + additive_model.to(dtype=torch.float64) + model.scaler.to(dtype=torch.float64) + + logging.info("Calculating composition weights") + model.additive_models[0].train_model( # this is the composition model + train_datasets, + model.additive_models[1:], + self.hypers["batch_size"], + is_distributed, + self.hypers["atomic_baseline"], + ) + + if self.hypers["scale_targets"]: + logging.info("Calculating scaling weights") + model.scaler.train_model( + train_datasets, + model.additive_models, + self.hypers["batch_size"], + is_distributed, + self.hypers["fixed_scaling_weights"], + ) + + logging.info("Setting up data loaders") + + if is_distributed: + train_samplers = [ + DistributedSampler( + train_dataset, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + ) + for train_dataset in train_datasets + ] + val_samplers = [ + DistributedSampler( + val_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + for val_dataset in val_datasets + ] + else: + train_samplers = [None] * len(train_datasets) + val_samplers = [None] * len(val_datasets) + + # Extract additive models and scaler and move them to CPU/float64 so they + # can be used in the collate function + model.additive_models[0].weights_to(device="cpu", dtype=torch.float64) + additive_models = copy.deepcopy( + model.additive_models.to(dtype=torch.float64, device="cpu") + ) + model.additive_models.to(device) + model.additive_models[0].weights_to(device=device, dtype=torch.float64) + model.scaler.scales_to(device="cpu", dtype=torch.float64) + scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu")) + model.scaler.to(device) + model.scaler.scales_to(device=device, dtype=torch.float64) + + # Create collate functions: + dataset_info = model.dataset_info + train_targets = dataset_info.targets + extra_data_info = dataset_info.extra_data + rotational_augmenter = RotationalAugmenter( + target_info_dict=train_targets, extra_data_info_dict=extra_data_info + ) + requested_neighbor_lists = get_requested_neighbor_lists(model) + collate_fn_train = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ + rotational_augmenter.apply_random_augmentations, + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + collate_fn_val = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ # no augmentation for validation + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + + # Create dataloader for the training datasets: + if self.hypers["num_workers"] is None: + num_workers = get_num_workers() + logging.info( + "Number of workers for data-loading not provided and chosen " + f"automatically. Using {num_workers} workers." + ) + else: + num_workers = self.hypers["num_workers"] + validate_num_workers(num_workers) + + train_dataloaders = [] + for train_dataset, train_sampler in zip( + train_datasets, train_samplers, strict=True + ): + if len(train_dataset) < self.hypers["batch_size"]: + raise ValueError( + f"A training dataset has fewer samples " + f"({len(train_dataset)}) than the batch size " + f"({self.hypers['batch_size']}). " + "Please reduce the batch size." 
+ ) + train_dataloaders.append( + DataLoader( + dataset=train_dataset, + batch_size=self.hypers["batch_size"], + sampler=train_sampler, + shuffle=( + # the sampler takes care of this (if present) + train_sampler is None + ), + drop_last=( + # the sampler takes care of this (if present) + train_sampler is None + ), + collate_fn=collate_fn_train, + num_workers=num_workers, + ) + ) + train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True) + + # Create dataloader for the validation datasets: + val_dataloaders = [] + for val_dataset, val_sampler in zip(val_datasets, val_samplers, strict=True): + val_dataloaders.append( + DataLoader( + dataset=val_dataset, + batch_size=self.hypers["batch_size"], + sampler=val_sampler, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_val, + num_workers=num_workers, + ) + ) + val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False) + + if is_distributed: + model = DistributedDataParallel(model, device_ids=[device]) + + outputs_list = [] + for target_name, target_info in train_targets.items(): + outputs_list.append(target_name) + for gradient_name in target_info.gradients: + outputs_list.append(f"{target_name}_{gradient_name}_gradients") + + # Create a loss function: + loss_hypers = cast(Dict[str, LossSpecification], self.hypers["loss"]) # mypy + loss_fn = LossAggregator(targets=train_targets, config=loss_hypers) + logging.info("Using the following loss functions:") + for name, info in loss_fn.metadata.items(): + logging.info(f"{name}:") + main = {k: v for k, v in info.items() if k != "gradients"} + logging.info(main) + if "gradients" not in info or len(info["gradients"]) == 0: + continue + logging.info("With gradients:") + for grad, ginfo in info["gradients"].items(): + logging.info(f"\t{name}::{grad}: {ginfo}") + + if self.hypers["weight_decay"] is not None: + optimizer = torch.optim.AdamW( + model.parameters(), + lr=self.hypers["learning_rate"], + weight_decay=self.hypers["weight_decay"], + ) + else: + optimizer = torch.optim.Adam( + model.parameters(), lr=self.hypers["learning_rate"] + ) + + if self.optimizer_state_dict is not None and not is_finetune: + # try to load the optimizer state dict, but this is only possible + # if there are no new targets in the model (new parameters) + if not (model.module if is_distributed else model).has_new_targets: + optimizer.load_state_dict(self.optimizer_state_dict) + + # Create a learning rate scheduler + lr_scheduler = get_scheduler(optimizer, self.hypers, len(train_dataloader)) + + if self.scheduler_state_dict is not None and not is_finetune: + # same as the optimizer, try to load the scheduler state dict + if not (model.module if is_distributed else model).has_new_targets: + lr_scheduler.load_state_dict(self.scheduler_state_dict) + + per_structure_targets = self.hypers["per_structure_targets"] + + # Log the initial learning rate: + logging.info(f"Base learning rate: {self.hypers['learning_rate']}") + + start_epoch = -1 if self.epoch is None else self.epoch + epoch = start_epoch + + # Save the untrained model checkpoint: + if self.best_metric is None: + self.best_metric = float("inf") + logging.info("Saving untrained model checkpoint") + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = 
lr_scheduler.state_dict() + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_untrained.ckpt", + ) + + # Train the model: + logging.info("Starting training") + start_epoch = start_epoch + 1 + + for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]): + if is_distributed: + for train_sampler in train_samplers: + train_sampler.set_epoch(epoch) + train_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + val_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + if self.hypers["log_mae"]: + train_mae_calculator = MAEAccumulator( + self.hypers["log_separate_blocks"] + ) + val_mae_calculator = MAEAccumulator(self.hypers["log_separate_blocks"]) + + train_loss = 0.0 + for batch in train_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + optimizer.zero_grad() + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=True, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + train_loss_batch = loss_fn(predictions, targets, extra_data) + + if is_distributed: + # make sure all parameters contribute to the gradient calculation + # to make torch DDP happy + for param in model.parameters(): + train_loss_batch += 0.0 * param.sum() + + train_loss_batch.backward() + torch.nn.utils.clip_grad_norm_( + model.parameters(), self.hypers["grad_clip_norm"] + ) + optimizer.step() + lr_scheduler.step() + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(train_loss_batch) + train_loss += train_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + train_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + train_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_train_info = train_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_train_info.update( + train_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + val_loss = 0.0 + for batch in val_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=False, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + val_loss_batch = loss_fn(predictions, targets, extra_data) + + if 
is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(val_loss_batch) + val_loss += val_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + val_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + val_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_val_info = val_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_val_info.update( + val_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + # Now we log the information: + finalized_train_info = { + "loss": train_loss, + **finalized_train_info, + } + finalized_val_info = { + "loss": val_loss, + **finalized_val_info, + } + + if epoch == start_epoch: + metric_logger = MetricLogger( + log_obj=ROOT_LOGGER, + dataset_info=( + model.module if is_distributed else model + ).dataset_info, + initial_metrics=[finalized_train_info, finalized_val_info], + names=["training", "validation"], + ) + if epoch % self.hypers["log_interval"] == 0: + metric_logger.log( + metrics=[finalized_train_info, finalized_val_info], + epoch=epoch, + rank=rank, + learning_rate=optimizer.param_groups[0]["lr"], + ) + + val_metric = get_selected_metric( + finalized_val_info, self.hypers["best_model_metric"] + ) + if val_metric < self.best_metric: + self.best_metric = val_metric + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + if epoch % self.hypers["checkpoint_interval"] == 0: + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + self.epoch = epoch + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_{epoch}.ckpt", + ) + + # prepare for the checkpoint that will be saved outside the function + self.epoch = epoch + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + + if is_distributed: + torch.distributed.destroy_process_group() + + def save_checkpoint(self, model: ModelInterface, path: Union[str, Path]) -> None: + checkpoint = model.get_checkpoint() + if self.best_model_state_dict is not None: + self.best_model_state_dict["finetune_config"] = model.finetune_config + checkpoint.update( + { + "trainer_ckpt_version": self.__checkpoint_version__, + "train_hypers": self.hypers, + "epoch": self.epoch, + "optimizer_state_dict": self.optimizer_state_dict, + "scheduler_state_dict": self.scheduler_state_dict, + "best_epoch": self.best_epoch, + "best_metric": self.best_metric, + "best_model_state_dict": self.best_model_state_dict, + "best_optimizer_state_dict": self.best_optimizer_state_dict, + } + ) + torch.save( + checkpoint, + check_file_extension(path, ".ckpt"), + ) + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + hypers: TrainerHypers, + context: Literal["restart", "finetune"], + ) -> "Trainer": + trainer = cls(hypers) + trainer.optimizer_state_dict = 
checkpoint["optimizer_state_dict"] + trainer.scheduler_state_dict = checkpoint["scheduler_state_dict"] + trainer.epoch = checkpoint["epoch"] + trainer.best_epoch = checkpoint["best_epoch"] + trainer.best_metric = checkpoint["best_metric"] + trainer.best_model_state_dict = checkpoint["best_model_state_dict"] + trainer.best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"] + + return trainer + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["trainer_ckpt_version"] == v: + update = getattr(checkpoints, f"trainer_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["trainer_ckpt_version"] = v + 1 + + if checkpoint["trainer_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using " + f"trainer version {checkpoint['trainer_ckpt_version']}, while the " + f"current trainer version is {cls.__checkpoint_version__}." + ) + return checkpoint diff --git a/src/metatrain/pet_4/__init__.py b/src/metatrain/pet_4/__init__.py new file mode 100644 index 0000000000..01db7d8aa9 --- /dev/null +++ b/src/metatrain/pet_4/__init__.py @@ -0,0 +1,21 @@ +from .model import PET +from .trainer import Trainer + + +__model__ = PET +__trainer__ = Trainer +__capabilities__ = { + "supported_devices": __model__.__supported_devices__, + "supported_dtypes": __model__.__supported_dtypes__, +} + +__authors__ = [ + ("Sergey Pozdnyakov ", "@spozdn"), + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] + +__maintainers__ = [ + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] diff --git a/src/metatrain/pet_4/checkpoints.py b/src/metatrain/pet_4/checkpoints.py new file mode 100644 index 0000000000..020b1e6590 --- /dev/null +++ b/src/metatrain/pet_4/checkpoints.py @@ -0,0 +1,422 @@ +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + + +########################### +# MODEL ################### +########################### + + +def model_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "additive_models.0.model.type_to_index" not in state_dict: + state_dict["additive_models.0.model.type_to_index"] = state_dict.pop( + "additive_models.0.type_to_index" + ) + + +def model_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "train_hypers" in checkpoint: + finetune_config = checkpoint["train_hypers"].get("finetune", {}) + else: + finetune_config = {} + state_dict["finetune_config"] = finetune_config + + +def model_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["epoch"] = checkpoint.get("epoch") + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + if checkpoint["best_model_state_dict"] is not None: + checkpoint["best_model_state_dict"] = checkpoint.get("best_model_state_dict") + + +def model_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 checkpoint to v5. + + :param checkpoint: The checkpoint to update. 
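# --- Standalone sketch (not part of the patch): the sequential upgrade pattern ---
# Trainer.upgrade_checkpoint above migrates an old checkpoint one version at a time
# until it reaches the current version, so each update function only needs to know
# about two adjacent formats. The updates dict below stands in for the `checkpoints`
# module looked up via getattr in the real code; field names are hypothetical.
CURRENT_VERSION = 4


def update_v1_v2(ckpt: dict) -> None:
    ckpt["renamed_field"] = ckpt.pop("old_field", None)


def update_v2_v3(ckpt: dict) -> None:
    ckpt.setdefault("new_hyper", 0.5)


def update_v3_v4(ckpt: dict) -> None:
    ckpt.setdefault("another_hyper", [None, None])


updates = {1: update_v1_v2, 2: update_v2_v3, 3: update_v3_v4}

checkpoint = {"trainer_ckpt_version": 1, "old_field": 42}
for v in range(1, CURRENT_VERSION):
    if checkpoint["trainer_ckpt_version"] == v:
        updates[v](checkpoint)
        checkpoint["trainer_ckpt_version"] = v + 1

print(checkpoint)
# {'trainer_ckpt_version': 4, 'renamed_field': 42, 'new_hyper': 0.5, 'another_hyper': [None, None]}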
+ """ + checkpoint["model_data"]["dataset_info"]._atomic_types = list( + checkpoint["model_data"]["dataset_info"]._atomic_types + ) + + +def model_update_v5_v6(checkpoint: dict) -> None: + """ + Update a v5 checkpoint to v6. + + :param checkpoint: The checkpoint to update. + """ + if not checkpoint["best_model_state_dict"]: + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + +def model_update_v6_v7(checkpoint: dict) -> None: + """ + Update model checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to be updated. + """ + # this update consists in changes in the scaler + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if ( + "scaler.scales" not in state_dict + and "scaler.dummy_buffer" in state_dict + and "scaler.model.type_to_index" in state_dict + ): + continue # already updated + old_scales_tensor = state_dict.pop("scaler.scales") + old_output_name_to_output_index = {} + for target_index, target_name in enumerate( + checkpoint["model_data"]["dataset_info"].targets.keys() + ): + old_output_name_to_output_index[target_name] = target_index + state_dict["scaler.dummy_buffer"] = torch.tensor( + [0.0], dtype=old_scales_tensor.dtype + ) + state_dict["scaler.model.type_to_index"] = state_dict[ + "additive_models.0.model.type_to_index" + ] + for target_name, target_info in checkpoint["model_data"][ + "dataset_info" + ].targets.items(): + layout = target_info.layout + if layout.sample_names == ["system"]: + samples = Labels(["atomic_type"], torch.tensor([[-1]])) + + elif layout.sample_names == ["system", "atom"]: + samples = Labels( + ["atomic_type"], + torch.arange( + len(checkpoint["model_data"]["dataset_info"].atomic_types) + ).reshape(-1, 1), + ) + else: + raise ValueError( # will never happen + "Unknown sample kind. Please contact the developers." + ) + scales_tensormap = TensorMap( + keys=layout.keys, + blocks=[ + TensorBlock( + values=torch.full( # important when scale_targets=False + (len(samples), len(block.properties)), + old_scales_tensor[ + old_output_name_to_output_index[target_name] + ], + dtype=torch.float64, + ), + samples=samples, + components=[], + properties=block.properties, + ) + for block in layout.blocks() + ], + ) + state_dict[f"scaler.{target_name}_scaler_buffer"] = mts.save_buffer( + mts.make_contiguous(scales_tensormap) + ) + + +def model_update_v7_v8(checkpoint: dict) -> None: + """ + Update a v7 checkpoint to v8. + + :param checkpoint: The checkpoint to update. + """ + if "node_embedding.weight" in checkpoint["model_state_dict"]: + ############################################## + # **Updating the large-scale PET checkpoints** + ############################################## + + checkpoint["model_data"]["model_hypers"]["normalization"] = "RMSNorm" + checkpoint["model_data"]["model_hypers"]["activation"] = "SwiGLU" + checkpoint["model_data"]["model_hypers"]["d_node"] = ( + 4 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PreLN" + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "feedforward" + checkpoint["model_data"]["model_hypers"]["d_feedforward"] = ( + 2 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + if (state_dict := checkpoint.get("model_state_dict")) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + if "embedding." 
in k and "node" not in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if "node_embedding." in k: + k = k.replace("node_embedding.", "node_embedders.0.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + if "combination_rmsnorms" in k: + k = k.replace("combination_rmsnorms", "combination_norms") + new_state_dict[k] = v + checkpoint["model_state_dict"] = new_state_dict + + # for the large-scale checkpoints, the model for evaluation is always + # taken to be the last + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + else: + ############################################### + # **Updating the standard old PET checkpoints** + ############################################### + + # Adding the option for choosing the normalization type + if "normalization" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["normalization"] = "LayerNorm" + # Adding the option for choosing the activation function + if "activation" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["activation"] = "SiLU" + # Setting the node features dimension to be the same as d_pet if not specified + if "d_node" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["d_node"] = checkpoint[ + "model_data" + ]["model_hypers"]["d_pet"] + # Setting the default transformer type to PostLN if not specified + if "transformer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PostLN" + if "featurizer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "residual" + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + # Replacing the nn.Sequential MLP with a custom FeedForward module + if ".mlp.0" in k: + k = k.replace(".mlp.0", ".mlp.w_in") + if ".mlp.3" in k: + k = k.replace(".mlp.3", ".mlp.w_out") + # Moving the node embedder to a top-level PET model attribute + if "embedding." in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".node_embedder." in k: + key_content = k.split(".") + k = ".".join( + ["node_embedders", key_content[1], key_content[-1]] + ) + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + new_state_dict[k] = v + checkpoint[key] = new_state_dict + + +def model_update_v8_v9(checkpoint: dict) -> None: + """ + Update a v8 checkpoint to v9. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "finetune_config" in state_dict: + if "inherit_heads" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["inherit_heads"] = {} + if "method" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["method"] = "full" + + +def model_update_v9_v10(checkpoint: dict) -> None: + """ + Update a v9 checkpoint to v10. + + :param checkpoint: The checkpoint to update. 
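# --- Standalone sketch (not part of the patch): state-dict key migration ---
# The model_update_v7_v8 function above renames state-dict keys to the new module
# layout and drops the last row of the embedding matrices, which used to hold a
# dedicated "padding" atom type. The keys and tensor sizes below are illustrative.
import torch

old_state_dict = {
    "embedding.weight": torch.randn(5, 8),  # 4 species + 1 padding row
    "gnn_layers.0.mlp.0.weight": torch.randn(8, 8),
}

new_state_dict = {}
for key, value in old_state_dict.items():
    if "embedding." in key:
        key = key.replace("embedding.", "edge_embedder.")
        value = value[:-1, :]  # drop the embedding row reserved for the padding type
    if ".mlp.0" in key:
        key = key.replace(".mlp.0", ".mlp.w_in")  # nn.Sequential index -> named submodule
    new_state_dict[key] = value

print({k: tuple(v.shape) for k, v in new_state_dict.items()})
# {'edge_embedder.weight': (4, 8), 'gnn_layers.0.mlp.w_in.weight': (8, 8)}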
+ """ + # Adding the num_neighbors_adaptive hyperparameter if not present + if "num_neighbors_adaptive" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["num_neighbors_adaptive"] = None + if "cutoff_function" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["cutoff_function"] = "Cosine" + + +def model_update_v10_v11(checkpoint: dict) -> None: + """ + Update a v10 checkpoint to v11. + + :param checkpoint: The checkpoint to update. + """ + # Adding the attention_temperature hyperparameter if not present + if "attention_temperature" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["attention_temperature"] = 1.0 + + +########################### +# TRAINER ################# +########################### + + +def trainer_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 Trainer checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["scheduler_factor"] = checkpoint["train_hypers"].get( + "scheduler_factor", 0.5 + ) + + +def trainer_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 Trainer checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + +def trainer_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 Trainer checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + new_loss_hypers[target_name] = { + "type": old_loss_hypers["type"], + "weight": old_loss_hypers["weights"].get(target_name, 1.0), + "reduction": old_loss_hypers["reduction"], + "sliding_factor": old_loss_hypers.get("sliding_factor", None), + } + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 Trainer checkpoint to v5. + + :param checkpoint: The checkpoint to update. + """ + raise ValueError( + "In order to use this checkpoint, you need metatrain 2025.10 or earlier. " + "You can install it with `pip install metatrain==2025.10`." + ) + + +def trainer_update_v5_v6(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 5 to version 6. + + :param checkpoint: The checkpoint to update. + """ + # num_workers=0 means that the main process will do the data loading, which is + # equivalent to not setting it (this was the behavior before v6) + checkpoint["train_hypers"]["num_workers"] = 0 + + +def trainer_update_v6_v7(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["fixed_scaling_weights"] = {} + + +def trainer_update_v7_v8(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. 
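# --- Standalone sketch (hypothetical target names, not part of the patch) ---
# The loss-hypers restructuring performed by trainer_update_v3_v4 above expands a
# single flat loss specification into one specification per target, reusing old
# per-target weights where they exist and falling back to 1.0 otherwise.
old_loss = {"type": "mse", "reduction": "mean", "weights": {"energy": 10.0}}
target_names = ["energy", "mtt::dipole"]  # illustrative targets

new_loss = {
    name: {
        "type": old_loss["type"],
        "weight": old_loss["weights"].get(name, 1.0),
        "reduction": old_loss["reduction"],
        "sliding_factor": old_loss.get("sliding_factor", None),
    }
    for name in target_names
}
print(new_loss["energy"]["weight"], new_loss["mtt::dipole"]["weight"])  # 10.0 1.0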
+    """
+    # remove all `sliding_factor` entries from the loss hypers
+    old_loss_hypers = checkpoint["train_hypers"]["loss"].copy()
+    dataset_info = checkpoint["model_data"]["dataset_info"]
+    new_loss_hypers = {}
+
+    for target_name in dataset_info.targets.keys():
+        # retain everything except sliding_factor for each target
+        new_loss_hypers[target_name] = {
+            k: v
+            for k, v in old_loss_hypers[target_name].items()
+            if k != "sliding_factor"
+        }
+
+    checkpoint["train_hypers"]["loss"] = new_loss_hypers
+
+
+def trainer_update_v8_v9(checkpoint: dict) -> None:
+    """
+    Update trainer checkpoint from version 8 to version 9.
+
+    :param checkpoint: The checkpoint to update.
+    """
+    # Adding the empty finetune config if not present
+    if "finetune" not in checkpoint["train_hypers"]:
+        checkpoint["train_hypers"]["finetune"] = {
+            "read_from": None,
+            "method": "full",
+            "config": {},
+            "inherit_heads": {},
+        }
+
+
+def trainer_update_v9_v10(checkpoint: dict) -> None:
+    """
+    Update trainer checkpoint from version 9 to version 10.
+
+    :param checkpoint: The checkpoint to update.
+    """
+    # Adding the remove_composition_contribution hyper, enabled by default
+    checkpoint["train_hypers"]["remove_composition_contribution"] = True
+
+
+def trainer_update_v10_v11(checkpoint: dict) -> None:
+    """
+    Update trainer checkpoint from version 10 to version 11.
+
+    :param checkpoint: The checkpoint to update.
+    """
+    # - Remove the ``remove_composition_contribution`` hyper.
+    # - Rename ``fixed_composition_weights`` to ``atomic_baseline``.
+    # - If ``remove_composition_contribution`` is False, set all atomic baselines
+    #   to 0.0 for all targets.
+    use_atomic_baseline = checkpoint["train_hypers"].pop(
+        "remove_composition_contribution"
+    )
+    atomic_baseline = checkpoint["train_hypers"].pop("fixed_composition_weights")
+
+    if not use_atomic_baseline:
+        # Set all baselines to 0.0 so that no composition contribution is removed
+        dataset_info = checkpoint["model_data"]["dataset_info"]
+        atomic_baseline = {target_name: 0.0 for target_name in dataset_info.targets}
+
+    checkpoint["train_hypers"]["atomic_baseline"] = atomic_baseline
+
+
+def trainer_update_v11_v12(checkpoint: dict) -> None:
+    """
+    Update trainer checkpoint from version 11 to version 12.
+
+    :param checkpoint: The checkpoint to update.
+    """
+    checkpoint["train_hypers"]["batch_atom_bounds"] = [None, None]
diff --git a/src/metatrain/pet_4/documentation.py b/src/metatrain/pet_4/documentation.py
new file mode 100644
index 0000000000..c6f563a906
--- /dev/null
+++ b/src/metatrain/pet_4/documentation.py
@@ -0,0 +1,251 @@
+"""
+PET
+===
+
+PET is a cleaner, more user-friendly reimplementation of the original
+PET model :footcite:p:`pozdnyakov_smooth_2023`. It is designed for better
+modularity and maintainability, while preserving compatibility with the original
+PET implementation in ``metatrain``. It also adds new features such as long-range
+features, a better fine-tuning implementation, the possibility to train on
+arbitrary targets, and faster inference due to ``fast attention``.
+
+{{SECTION_INSTALLATION}}
+
+{{SECTION_DEFAULT_HYPERS}}
+
+Tuning hyperparameters
+----------------------
+
+The default hyperparameters above will work well in most cases, but they
+may not be optimal for your specific dataset. There is a good number of
+parameters to tune, both for the
+:ref:`model ` and the
+:ref:`trainer `. Since seeing them
+for the first time might be overwhelming, here we provide a **list of the
+parameters that are in general the most important** (in decreasing order
+of importance):
+
+.. container:: mtt-hypers-remove-classname
+
+    ..
autoattribute:: {{model_hypers_path}}.cutoff + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_neighbors_adaptive + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.learning_rate + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.batch_size + :no-index: + + .. autoattribute:: {{model_hypers_path}}.d_pet + :no-index: + + .. autoattribute:: {{model_hypers_path}}.d_node + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_gnn_layers + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_attention_layers + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.loss + :no-index: + + .. autoattribute:: {{model_hypers_path}}.long_range + :no-index: +""" + +from typing import Literal, Optional + +from typing_extensions import TypedDict + +from metatrain.pet_4.modules.finetuning import FinetuneHypers, NoFinetuneHypers +from metatrain.utils.additive import FixedCompositionWeights +from metatrain.utils.hypers import init_with_defaults +from metatrain.utils.long_range import LongRangeHypers +from metatrain.utils.loss import LossSpecification +from metatrain.utils.scaler import FixedScalerWeights + + +class ModelHypers(TypedDict): + """Hyperparameters for the PET model.""" + + cutoff: float = 4.5 + """Cutoff radius for neighbor search. + + This should be set to a value after which most of the interactions + between atoms is expected to be negligible. A lower cutoff will lead + to faster models. + """ + num_neighbors_adaptive: Optional[int] = None + """Target number of neighbors for the adaptive cutoff scheme. + + This parameter activates the adaptive cutoff functionality. + Each atomic environments has a different cutoff, that is chosen + such that the number of neighbors is approximately equal to this + value. This can be useful to have a more uniform number of neighbors + per atom, especially in sparse systems. Setting it to None disables + this feature and uses all neighbors within the fixed cutoff radius. + """ + cutoff_function: Literal["Cosine", "Bump"] = "Bump" + """Type of the smoothing function at the cutoff""" + cutoff_width: float = 0.5 + """Width of the smoothing function at the cutoff""" + d_pet: int = 128 + """Dimension of the edge features. + + This hyperparameters controls width of the neural network. In general, + increasing it might lead to better accuracy, especially on larger datasets, at the + cost of increased training and evaluation time. + """ + d_head: int = 128 + """Dimension of the attention heads.""" + d_node: int = 256 + """Dimension of the node features. + + Increasing this hyperparameter might lead to better accuracy, + with a relatively small increase in inference time. + """ + d_feedforward: int = 256 + """Dimension of the feedforward network in the attention layer.""" + num_heads: int = 8 + """Attention heads per attention layer.""" + num_attention_layers: int = 2 + """The number of attention layers in each layer of the graph + neural network. Depending on the dataset, increasing this hyperparameter might + lead to better accuracy, at the cost of increased training and evaluation time. + """ + num_gnn_layers: int = 2 + """The number of graph neural network layers. + + In general, decreasing this hyperparameter to 1 will lead to much faster models, + at the expense of accuracy. Increasing it may or may not lead to better accuracy, + depending on the dataset, at the cost of increased training and evaluation time. 
+ """ + normalization: Literal["RMSNorm", "LayerNorm"] = "RMSNorm" + """Layer normalization type.""" + activation: Literal["SiLU", "SwiGLU"] = "SwiGLU" + """Activation function.""" + attention_temperature: float = 1.0 + """The temperature scaling factor for attention scores.""" + transformer_type: Literal["PreLN", "PostLN"] = "PreLN" + """The order in which the layer normalization and attention + are applied in a transformer block. Available options are ``PreLN`` + (normalization before attention) and ``PostLN`` (normalization after attention).""" + featurizer_type: Literal["residual", "feedforward"] = "feedforward" + """Implementation of the featurizer of the model to use. Available + options are ``residual`` (the original featurizer from the PET paper, that uses + residual connections at each GNN layer for readout) and ``feedforward`` (a modern + version that uses the last representation after all GNN iterations for readout). + Additionally, the feedforward version uses bidirectional features flow during the + message passing iterations, that favors features flowing from atom ``i`` to atom + ``j`` to be not equal to the features flowing from atom ``j`` to atom ``i``.""" + zbl: bool = False + """Use ZBL potential for short-range repulsion""" + long_range: LongRangeHypers = init_with_defaults(LongRangeHypers) + """Long-range Coulomb interactions parameters.""" + + +class TrainerHypers(TypedDict): + """Hyperparameters for training PET models.""" + + distributed: bool = False + """Whether to use distributed training""" + distributed_port: int = 39591 + """Port for distributed communication among processes""" + batch_size: int = 16 + """The number of samples to use in each batch of training. This + hyperparameter controls the tradeoff between training speed and memory usage. In + general, larger batch sizes will lead to faster training, but might require more + memory.""" + num_epochs: int = 1000 + """Number of epochs.""" + warmup_fraction: float = 0.01 + """Fraction of training steps used for learning rate warmup.""" + learning_rate: float = 1e-4 + """Learning rate.""" + weight_decay: Optional[float] = None + + log_interval: int = 1 + """Interval to log metrics.""" + checkpoint_interval: int = 100 + """Interval to save checkpoints.""" + atomic_baseline: FixedCompositionWeights = {} + """The baselines for each target. + + By default, ``metatrain`` will fit a linear model (:class:`CompositionModel + `) to compute the + least squares baseline for each atomic species for each target. + + However, this hyperparameter allows you to provide your own baselines. + The value of the hyperparameter should be a dictionary where the keys are the + target names, and the values are either (1) a single baseline to be used for + all atomic types, or (2) a dictionary mapping atomic types to their baselines. + For example: + + - ``atomic_baseline: {"energy": {1: -0.5, 6: -10.0}}`` will fix the energy + baseline for hydrogen (Z=1) to -0.5 and for carbon (Z=6) to -10.0, while + fitting the baselines for the energy of all other atomic types, as well + as fitting the baselines for all other targets. + - ``atomic_baseline: {"energy": -5.0}`` will fix the energy baseline for + all atomic types to -5.0. + - ``atomic_baseline: {"mtt:dos": 0.0}`` sets the baseline for the "mtt:dos" + target to 0.0, effectively disabling the atomic baseline for that target. 
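# --- Standalone worked example (made-up numbers, not part of the patch) ---
# As noted above, the atomic baseline is a per-atom contribution, so for a summed
# target such as the total energy it enters once per atom of each type. Here, a
# water molecule with hypothetical baselines of -0.5 for H (Z=1) and -10.0 for O (Z=8).
atomic_baseline = {1: -0.5, 8: -10.0}
atoms_in_structure = [8, 1, 1]  # one O, two H

baseline_total = sum(atomic_baseline[z] for z in atoms_in_structure)
print(baseline_total)  # -11.0: subtracted from the target during training,
                       # added back to the prediction at evaluation time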
+ + This atomic baseline is substracted from the targets during training, which + avoids the main model needing to learn atomic contributions, and likely makes + training easier. When the model is used in evaluation mode, the atomic baseline + is added on top of the model predictions automatically. + + .. note:: + + This atomic baseline is a per-atom contribution. Therefore, if the property + you are predicting is a sum over all atoms (e.g., total energy), the + contribution of the atomic baseline to the total property will be the + atomic baseline multiplied by the number of atoms of that type in the + structure. + """ + scale_targets: bool = True + """Normalize targets to unit std during training.""" + fixed_scaling_weights: FixedScalerWeights = {} + """Weights for target scaling. + + This is passed to the ``fixed_weights`` argument of + :meth:`Scaler.train_model `, + see its documentation to understand exactly what to pass here. + """ + per_structure_targets: list[str] = [] + """Targets to calculate per-structure losses.""" + num_workers: Optional[int] = None + """Number of workers for data loading. If not provided, it is set + automatically.""" + log_mae: bool = True + """Log MAE alongside RMSE""" + log_separate_blocks: bool = False + """Log per-block error.""" + best_model_metric: Literal["rmse_prod", "mae_prod", "loss"] = "mae_prod" + """Metric used to select best checkpoint (e.g., ``rmse_prod``)""" + grad_clip_norm: float = 1.0 + """Maximum gradient norm value.""" + loss: str | dict[str, LossSpecification | str] = "mse" + """This section describes the loss function to be used. See the + :ref:`loss-functions` for more details.""" + batch_atom_bounds: list[Optional[int]] = [None, None] + """Bounds for the number of atoms per batch as [min, max]. Batches with atom + counts outside these bounds will be skipped during training. Use ``None`` for + either value to disable that bound. This is useful for preventing out-of-memory + errors and ensuring consistent computational load. Default: ``[None, None]``.""" + + finetune: NoFinetuneHypers | FinetuneHypers = { + "read_from": None, + "method": "full", + "config": {}, + "inherit_heads": {}, + } + """Parameters for fine-tuning trained PET models. + + See :ref:`label_fine_tuning_concept` for more details. + """ diff --git a/src/metatrain/pet_4/model.py b/src/metatrain/pet_4/model.py new file mode 100644 index 0000000000..b830d73ec1 --- /dev/null +++ b/src/metatrain/pet_4/model.py @@ -0,0 +1,1658 @@ +import logging +import typing +import warnings +from math import prod +from typing import Any, Dict, List, Literal, Optional, Tuple + +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.operations._add import _add_block_block +from metatomic.torch import ( + AtomisticModel, + ModelCapabilities, + ModelMetadata, + ModelOutput, + NeighborListOptions, + System, +) +from torch.utils.hooks import RemovableHandle + +from metatrain.utils.abc import ModelInterface +from metatrain.utils.additive import ZBL, CompositionModel +from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.dtype import dtype_to_str +from metatrain.utils.long_range import DummyLongRangeFeaturizer, LongRangeFeaturizer +from metatrain.utils.metadata import merge_metadata +from metatrain.utils.scaler import Scaler +from metatrain.utils.sum_over_atoms import sum_over_atoms + +from . 
import checkpoints +from .documentation import ModelHypers +from .modules.finetuning import apply_finetuning_strategy +from .modules.readout import ReadoutLayer +from .modules.structures import get_pair_sample_labels, systems_to_batch +from .modules.transformer import CartesianTransformer + + +AVAILABLE_FEATURIZERS = typing.get_args(ModelHypers.__annotations__["featurizer_type"]) + + +class PET(ModelInterface[ModelHypers]): + """ + Metatrain-native implementation of the PET architecture. + + Originally proposed in work (https://arxiv.org/abs/2305.19302v3), + and published in the `pet` package (https://github.com/spozdn/pet). + + :param hypers: Hyperparameters for the PET model. See the documentation for details. + :param dataset_info: Information about the dataset, including atomic types and + targets. + """ + + __checkpoint_version__ = 11 + __supported_devices__ = ["cuda", "cpu"] + __supported_dtypes__ = [torch.float32, torch.float64] + __default_metadata__ = ModelMetadata( + references={"architecture": ["https://arxiv.org/abs/2305.19302v3"]} + ) + component_labels: Dict[str, List[List[Labels]]] + NUM_FEATURE_TYPES: int = 2 # node + edge features + + def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None: + super().__init__(hypers, dataset_info, self.__default_metadata__) + + # Cache frequently accessed hyperparameters + self.cutoff = float(self.hypers["cutoff"]) + self.cutoff_function = self.hypers["cutoff_function"] + self.cutoff_width = float(self.hypers["cutoff_width"]) + self.num_neighbors_adaptive = ( + float(self.hypers["num_neighbors_adaptive"]) + if self.hypers["num_neighbors_adaptive"] is not None + else None + ) + self.d_pet = self.hypers["d_pet"] + self.d_node = self.hypers["d_node"] + self.d_head = self.hypers["d_head"] + self.d_feedforward = self.hypers["d_feedforward"] + self.num_heads = self.hypers["num_heads"] + self.num_gnn_layers = self.hypers["num_gnn_layers"] + self.num_attention_layers = self.hypers["num_attention_layers"] + self.normalization = self.hypers["normalization"] + self.activation = self.hypers["activation"] + self.attention_temperature = self.hypers["attention_temperature"] + self.transformer_type = self.hypers["transformer_type"] + self.featurizer_type = self.hypers["featurizer_type"] + + self.atomic_types = dataset_info.atomic_types + self.requested_nl = NeighborListOptions( + cutoff=self.cutoff, + full_list=True, + strict=True, + ) + num_atomic_species = len(self.atomic_types) + self.gnn_layers = torch.nn.ModuleList( + [ + CartesianTransformer( + self.cutoff, + self.cutoff_width, + self.d_pet, + self.num_heads, + self.d_node, + self.d_feedforward, + self.num_attention_layers, + self.normalization, + self.activation, + self.attention_temperature, + self.transformer_type, + num_atomic_species, + layer_index == 0, # is first layer + ) + for layer_index in range(self.num_gnn_layers) + ] + ) + if self.featurizer_type not in AVAILABLE_FEATURIZERS: + raise ValueError( + f"Unknown featurizer type: {self.featurizer_type}. 
" + f"Available options are: {AVAILABLE_FEATURIZERS}" + ) + if self.featurizer_type == "feedforward": + self.num_readout_layers = 1 + self.combination_norms = torch.nn.ModuleList( + [torch.nn.LayerNorm(2 * self.d_pet) for _ in range(self.num_gnn_layers)] + ) + self.combination_mlps = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(2 * self.d_pet, 2 * self.d_pet), + torch.nn.SiLU(), + torch.nn.Linear(2 * self.d_pet, self.d_pet), + ) + for _ in range(self.num_gnn_layers) + ] + ) + else: + self.num_readout_layers = self.num_gnn_layers + self.combination_norms = torch.nn.ModuleList() + self.combination_mlps = torch.nn.ModuleList() + + self.node_embedders = torch.nn.ModuleList( + [ + torch.nn.Embedding(num_atomic_species, self.d_node) + for _ in range(self.num_readout_layers) + ] + ) + self.edge_embedder = torch.nn.Embedding(num_atomic_species, self.d_pet) + + self.node_heads = torch.nn.ModuleDict() + self.edge_heads = torch.nn.ModuleDict() + self.node_last_layer_shifts = torch.nn.ModuleDict() + self.edge_last_layer_shifts = torch.nn.ModuleDict() + self.node_last_layers = torch.nn.ModuleDict() + self.edge_last_layers = torch.nn.ModuleDict() + self.last_layer_feature_size = ( + self.num_readout_layers * self.d_head * self.NUM_FEATURE_TYPES + ) # for LLPR + + # the model is always capable of outputting the internal features + self.outputs = { + "features": ModelOutput(per_atom=True, description="internal features") + } + + self.output_shapes: Dict[str, Dict[str, List[int]]] = {} + self.key_labels: Dict[str, Labels] = {} + self.property_labels: Dict[str, List[Labels]] = {} + self.component_labels: Dict[str, List[List[Labels]]] = {} + self.target_names: List[str] = [] + self.last_layer_parameter_names: Dict[str, List[str]] = {} # for LLPR + for target_name, target_info in dataset_info.targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target_info) + + self.register_buffer( + "species_to_species_index", + torch.full((max(self.atomic_types) + 1,), -1), + ) + for i, species in enumerate(self.atomic_types): + self.species_to_species_index[species] = i + + # long-range module + if self.hypers["long_range"]["enable"]: + self.long_range = True + if not self.hypers["long_range"]["use_ewald"]: + warnings.warn( + "Training PET with the LongRangeFeaturizer initialized " + "with `use_ewald=False` causes instabilities during training. " + "The `use_ewald` variable will be force-switched to `True`. 
" + "during training.", + UserWarning, + stacklevel=2, + ) + self.long_range_featurizer = LongRangeFeaturizer( + hypers=self.hypers["long_range"], + feature_dim=self.d_node, + neighbor_list_options=self.requested_nl, + ) + else: + self.long_range = False + self.long_range_featurizer = DummyLongRangeFeaturizer() # for torchscript + + # additive models: these are handled by the trainer at training + # time, and they are added to the output at evaluation time + composition_model = CompositionModel( + hypers={}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + additive_models = [composition_model] + + # Adds the ZBL repulsion model if requested + if self.hypers["zbl"]: + additive_models.append( + ZBL( + {}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if ZBL.is_valid_target(target_name, target_info) + }, + ), + ) + ) + self.additive_models = torch.nn.ModuleList(additive_models) + + # scaler: this is also handled by the trainer at training time + self.scaler = Scaler(hypers={}, dataset_info=dataset_info) + + self.single_label = Labels.single() + + self.finetune_config: Dict[str, Any] = {} + + def supported_outputs(self) -> Dict[str, ModelOutput]: + return self.outputs + + def restart(self, dataset_info: DatasetInfo) -> "PET": + # merge old and new dataset info + merged_info = self.dataset_info.union(dataset_info) + new_atomic_types = [ + at for at in merged_info.atomic_types if at not in self.atomic_types + ] + new_targets = { + key: value + for key, value in merged_info.targets.items() + if key not in self.dataset_info.targets + } + self.has_new_targets = len(new_targets) > 0 + + if len(new_atomic_types) > 0: + raise ValueError( + f"New atomic types found in the dataset: {new_atomic_types}. " + "The PET model does not support adding new atomic types." + ) + + # register new outputs as new last layers + for target_name, target in new_targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target) + + self.dataset_info = merged_info + + # restart the composition and scaler models + self.additive_models[0] = self.additive_models[0].restart( + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + self.scaler = self.scaler.restart(dataset_info) + + return self + + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return [self.requested_nl] + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + """ + Forward pass of the PET model. + + The forward pass processes atomic systems through multiple stages to produce + predictions for the requested outputs. The computation follows a graph neural + network architecture with attention-based message passing. 
+ + **Stage 0: Input Preparation** + + The input systems are first converted into a batched representation containing: + + - `element_indices_nodes` [n_atoms]: Atomic species of the central atoms + - `element_indices_neighbors` [n_atoms, max_num_neighbors]: Atomic species of + neighboring atoms + - `edge_vectors` [n_atoms, max_num_neighbors, 3]: Cartesian edge vectors + between central atoms and their neighbors + - `padding_mask` [n_atoms, max_num_neighbors]: Mask indicating real vs padded + neighbors + - `reverse_neighbor_index` [n_atoms * max_num_neighbors]: Index of the ji edge + for each ij edge, once the edges are flattened into an array whose first + dimension is n_atoms * max_num_neighbors + - `system_indices` [n_atoms]: System index for each central atom + - `sample_labels` [n_atoms, 2]: Metatensor Labels containing indices of each + atom in each system + + **Stage 1: Feature Computation via GNN Layers** + + Node and edge representations are computed by iterating through the GNN layers + following one of two featurization strategies: + + - **Feedforward featurization**: Features are propagated through all + `num_gnn_layers` GNN layers sequentially, using only the final layer outputs + for readout. At each layer, forward and reversed edge messages are combined + using combination MLPs to enable bidirectional information flow. + + - **Residual featurization**: Intermediate node and edge features from each + GNN layer are saved and used during readout. Edge messages between layers are + averaged to maintain information from all hops. + + During this stage, the model: + + - Embeds atomic species into learned node and edge representations + - Applies Cartesian transformer layers to update features via attention + - Uses reversed neighbor lists to enable bidirectional message passing, where + the new input message from atom `j` to atom `i` in GNN layer N+1 is the + reversed message from atom `i` to atom `j` in GNN layer N + - Applies cutoff functions to weight interactions by distance + + If the long-range module is enabled, electrostatic features computed via Ewald + summation (during training) or Particle-Particle Particle Mesh Ewald (P3M) + (during evaluation) are added to the node features from each GNN layer. + + **Stage 2: Intermediate Feature Output (Optional)** + + If "features" is requested in the outputs, node and edge features from all + layers are concatenated to produce intermediate representations. Edge features + are summed over neighbors with cutoff weighting to obtain per-node + contributions. This output can be used for transfer learning or analysis. + + **Stage 3: Last Layer Feature Computation** + + For each requested output, output-specific heads (shallow MLPs with two linear + layers and SiLU activations) are applied to both node and edge features from + each GNN layer. This produces last layer features that are specialized for each + prediction target. These features can be optionally returned as auxiliary + outputs (e.g., "mtt::aux::energy_last_layer_features") for analysis or + transfer learning. 
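
# [Editorial illustration, not part of the patch] A toy version of the
# reversed-neighbor reindexing described above (and used by the featurizers
# further down): edges are flattened, gathered with `reverse_neighbor_index`,
# and reshaped back, so each ij slot ends up holding the ji message.
import torch

n_atoms, max_num_neighbors, d = 2, 1, 2  # two atoms, each the other's only neighbor
messages = torch.arange(n_atoms * max_num_neighbors * d, dtype=torch.float32)
messages = messages.reshape(n_atoms, max_num_neighbors, d)
# Flat edge 0 is (i=0 -> j=1) and flat edge 1 is (i=1 -> j=0), so each edge is
# the other's reverse.
reverse_neighbor_index = torch.tensor([1, 0])
reversed_messages = messages.reshape(-1, d)[reverse_neighbor_index]
reversed_messages = reversed_messages.reshape(n_atoms, max_num_neighbors, d)
# The message atom 0 receives in the next layer is the one atom 1 sent to it.
assert torch.equal(reversed_messages[0, 0], messages[1, 0])
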
+ + **Stage 4: Atomic Predictions** + + Final linear layers are applied to the last layer features to produce per-atom + predictions for each requested output: + + - Node and edge last layer features are processed through separate linear + layers for each output block + - Contributions from all GNN layers are summed + - Edge contributions are summed over neighbors with cutoff weighting + - For rank-2 Cartesian tensors (e.g., stress), predictions are symmetrized and + normalized by cell volume + - Multiple tensor blocks per output are handled independently + + **Post-processing (Evaluation Only)** + + During evaluation (not training), the following transformations are applied: + + 1. **Scaling**: Predictions are scaled using learned or configured scale + factors + 2. **Additive contributions**: Composition model and optional ZBL repulsion + contributions are added to the predictions + + :param systems: List of `metatomic.torch.System` objects to process. Each + system should contain atomic positions, species, and cell information, with + neighbor lists computed according to `requested_neighbor_lists()`. + :param outputs: Dictionary of requested outputs in the format + {output_name: ModelOutput(...)}. The model supports: + + - Target properties (energy, forces, stress, etc.) + - "features": intermediate representations from Stage 2 + - Auxiliary last layer features (e.g., + "mtt::aux::energy_last_layer_features") + + :param selected_atoms: Optional `metatensor.torch.Labels` object specifying a + subset of atoms for which to compute outputs. If `None`, all atoms are + included. This is useful for computing properties for specific atomic + environments. + :return: Dictionary of `metatensor.torch.TensorMap` objects containing the + requested outputs. Each TensorMap contains per-atom or per-structure + predictions (depending on the ModelOutput configuration) with appropriate + metatensor metadata (samples, components, properties). 
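
# [Editorial illustration, not part of the patch] A hedged sketch of how this
# forward pass is typically driven. It assumes an instantiated `model` of the
# PET class below, trained on an "energy" target, and systems that already carry
# the neighbor list described by `model.requested_neighbor_lists()`.
# "mtt::features::gnn_layers.0_node" is one of the capturable module paths
# listed in `_prepare_diagnostic_handles` further down.
from typing import Dict, List

from metatensor.torch import TensorMap
from metatomic.torch import ModelOutput, System


def predict_energy_and_features(model, systems: List[System]) -> Dict[str, TensorMap]:
    requested = {
        "energy": ModelOutput(per_atom=False),   # summed over atoms
        "features": ModelOutput(per_atom=True),  # Stage-2 internal features
        "mtt::aux::energy_last_layer_features": ModelOutput(per_atom=True),
        "mtt::features::gnn_layers.0_node": ModelOutput(per_atom=True),
    }
    return model(systems, requested)
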
+ """ + device = systems[0].device + return_dict: Dict[str, TensorMap] = {} + nl_options = self.requested_neighbor_lists()[0] + + if self.single_label.values.device != device: + self._move_labels_to_device(device) + + # **Stage 0: Input Preparation** + ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + padding_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + sample_labels, + centers, + nef_to_edges_neighbor, + ) = systems_to_batch( + systems, + nl_options, + self.atomic_types, + self.species_to_species_index, + self.cutoff_function, + self.cutoff_width, + self.num_neighbors_adaptive, + selected_atoms, + ) + + pair_sample_labels = get_pair_sample_labels( + systems, sample_labels, nl_options, device + ) + + # Optional diagnostic token capture: register temporary module hooks + diagnostic_handles = torch.jit.annotate(List[Any], []) + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + diagnostic_handles = self._prepare_diagnostic_handles( + outputs, + return_dict, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + + # the scaled_dot_product_attention function from torch cannot do + # double backward, so we will use manual attention if needed + use_manual_attention = edge_vectors.requires_grad and self.training + + with torch.profiler.record_function("PET::_calculate_features"): + # **Stage 1: Feature Computation via GNN Layers** + featurizer_inputs: Dict[str, torch.Tensor] = dict( + element_indices_nodes=element_indices_nodes, + element_indices_neighbors=element_indices_neighbors, + edge_vectors=edge_vectors, + edge_distances=edge_distances, + reverse_neighbor_index=reverse_neighbor_index, + padding_mask=padding_mask, + cutoff_factors=cutoff_factors, + ) + for featurizer_input_name, tensor in featurizer_inputs.items(): + if "mtt::features::" + featurizer_input_name in outputs: + return_dict["mtt::features::" + featurizer_input_name] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + node_features_list, edge_features_list = self._calculate_features( + featurizer_inputs, + use_manual_attention=use_manual_attention, + ) + + # If the long-range module is activated, we add the long-range features + # on top of the node features + + if self.long_range: + long_range_features = self._calculate_long_range_features( + systems, node_features_list, edge_distances, padding_mask + ) + for i in range(self.num_readout_layers): + node_features_list[i] = ( + node_features_list[i] + long_range_features + ) * 0.5**0.5 + + # **Stage 2: Intermediate Feature Output (Optional)** + with torch.profiler.record_function("PET::_get_output_features"): + if "features" in outputs: + features_dict = self._get_output_features( + node_features_list, + edge_features_list, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + # Since return_dict.update(features_dict) is not Torch-Scriptable, + # we use a simple iteration over the features_dict items. 
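
# [Editorial note, not part of the patch] The 0.5 ** 0.5 factor used above when
# adding the long-range features presumably keeps the combined features at
# roughly the same scale: for independent, roughly unit-variance tensors,
# (a + b) / sqrt(2) has unit variance again. A quick numerical check:
import torch

a, b = torch.randn(100_000), torch.randn(100_000)
print(a.var().item(), ((a + b) * 0.5**0.5).var().item())  # both close to 1.0
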
+ for k, v in features_dict.items(): + return_dict[k] = v + + # **Stage 3: Last Layer Feature Computation** + with torch.profiler.record_function("PET::_calculate_last_layer_features"): + node_last_layer_features_dict, edge_last_layer_features_dict = ( + self._calculate_last_layer_features( + node_features_list, + edge_features_list, + ) + ) + last_layer_features_dict = self._get_output_last_layer_features( + node_last_layer_features_dict, + edge_last_layer_features_dict, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + + for k, v in last_layer_features_dict.items(): + return_dict[k] = v + + # **Stage 4: Atomic Predictions** + with torch.profiler.record_function("PET::_calculate_atomic_predictions"): + node_atomic_predictions_dict, edge_atomic_predictions_dict = ( + self._calculate_atomic_predictions( + systems, + node_last_layer_features_dict, + edge_last_layer_features_dict, + padding_mask, + cutoff_factors, + outputs, + ) + ) + atomic_predictions_dict = self._get_output_atomic_predictions( + systems, + node_atomic_predictions_dict, + edge_atomic_predictions_dict, + edge_vectors, + system_indices, + sample_labels, + outputs, + selected_atoms, + ) + + for k, v in atomic_predictions_dict.items(): + return_dict[k] = v + + # **Post-processing (Evaluation Only)** + + with torch.profiler.record_function("PET::post-processing"): + if not self.training: + # at evaluation, we also introduce the scaler and additive contributions + return_dict = self.scaler( + systems, return_dict, selected_atoms=selected_atoms + ) + for additive_model in self.additive_models: + outputs_for_additive_model: Dict[str, ModelOutput] = {} + for name, output in outputs.items(): + if name in additive_model.outputs: + outputs_for_additive_model[name] = output + additive_contributions = additive_model( + systems, + outputs_for_additive_model, + selected_atoms, + ) + for name in additive_contributions: + # TODO: uncomment this after metatensor.torch.add + # is updated to handle sparse sums + # return_dict[name] = metatensor.torch.add( + # return_dict[name], + # additive_contributions[name].to( + # device=return_dict[name].device, + # dtype=return_dict[name].dtype + # ), + # ) + # TODO: "manual" sparse sum: update to metatensor.torch.add + # after sparse sum is implemented in metatensor.operations + output_blocks: List[TensorBlock] = [] + for k, b in return_dict[name].items(): + if k in additive_contributions[name].keys: + output_blocks.append( + _add_block_block( + b, + additive_contributions[name] + .block(k) + .to(device=b.device, dtype=b.dtype), + ) + ) + else: + output_blocks.append(b) + return_dict[name] = TensorMap( + return_dict[name].keys, output_blocks + ) + + # remove any diagnostic hooks we registered and attach tokens + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + for h in diagnostic_handles: + h.remove() + + return return_dict + + def _create_diagnostic_feature_tensormap( + self, + tensor: torch.Tensor, + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> TensorMap: + assert tensor.shape[0] == sample_labels.values.shape[0], ( + "diagnostic feature tensor must be per-atom or per-pair like in shape." 
+ f" Got tensor.shape = {tensor.shape}, " + ) + + device = tensor.device + + outp = tensor.detach().clone() + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + if outp.shape[1] == 1: # node-like, shape (n_atoms, 1, d) + outp = outp.squeeze(1) + labels = sample_labels + + else: # edge-like, shape (n_atoms, num_neighbors, d) + outp = outp[centers, nef_to_edges_neighbor] + labels = pair_sample_labels + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + return TensorMap( + Labels(["_"], torch.tensor([[0]])).to(device=device), + [ + TensorBlock( + values=outp, + samples=labels.to(device=device), + components=[], + properties=Labels( + ["_"], + torch.arange(outp.shape[1]).reshape(-1, 1), + ).to(device=device), + ) + ], + ) + + @torch.jit.ignore + def _prepare_diagnostic_handles( + self, + outputs: Dict[str, ModelOutput], + return_dict: Dict[str, Any], + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> List[Any]: + """ + Prepare forward hooks to capture diagnostic tokens from internal modules. + + :param outputs: Dictionary of requested outputs. + :param return_dict: Dictionary to store captured tokens. + :param centers: Tensor mapping center atoms to their indices. + :param nef_to_edges_neighbor: Tensor mapping neighbor edges to their indices. + :param sample_labels: Labels for individual atoms. + :param pair_sample_labels: Labels for atom pairs. + + :return: List of removable handles for the registered hooks. + """ + diagnostic_handles = [] + + def _resolve_module(path: str) -> Any: + obj: Any = self + for part in path.split("."): + if part.isdigit(): + obj = obj[int(part)] + else: + if not hasattr(obj, part): + raise AttributeError( + f"Module path '{path}' not found at '{part}'" + ) + obj = getattr(obj, part) + return obj + + # Build list of possible module paths that can be captured for these model + # hypers. 
+ possible_capture_paths: List[str] = [] + + for i in range(self.num_readout_layers): + possible_capture_paths.append(f"node_embedders.{i}") + possible_capture_paths.append("edge_embedder") + + for i in range(self.num_gnn_layers): + # Total GNN layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}_node") + possible_capture_paths.append(f"gnn_layers.{i}_edge") + + # Finer-grained: embeddings and compressions + possible_capture_paths.append(f"gnn_layers.{i}.edge_embedder") + if i > 0: + possible_capture_paths.append(f"gnn_layers.{i}.neighbor_embedder") + possible_capture_paths.append(f"gnn_layers.{i}.compress") + + # transformer layers + for j in range(self.num_attention_layers): + # Transformer layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_node") + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_edge") + + # Finer-grained: attention and MLP submodules + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_contraction" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_mlp" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_mlp" + ) + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}.mlp") + + # TODO: this is subtle, depending on whether the featurizer is feedforward or + # residual. To fix later. + # possible_capture_paths.append(f"combination_norms") + # possible_capture_paths.append(f"combination_mlps") + + for path in possible_capture_paths: + if "mtt::features::" + path not in outputs: + continue + + if "_node" in path: + suffix = "_node" + path = path.replace(suffix, "") + elif "_edge" in path: + suffix = "_edge" + path = path.replace(suffix, "") + else: + suffix = "" + + module = _resolve_module(path) + + def make_hook(p: str, suffix: str) -> Any: + def _hook(module: torch.nn.Module, inp: Any, outp: Any) -> None: + if isinstance(outp, tuple): + assert "_node" in suffix or "_edge" in suffix, ( + "When capturing from a module that returns multiple " + "outputs, the requested output must carry the suffix " + "'_node' or '_edge'." + ) + if suffix == "_node": + tensor = outp[0] + else: + assert suffix == "_edge" + tensor = outp[1] + + return_dict[f"mtt::features::{p}{suffix}"] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + else: + return_dict["mtt::features::" + p] = ( + self._create_diagnostic_feature_tensormap( + outp, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + + return _hook + + handle = module.register_forward_hook(make_hook(path, suffix)) + diagnostic_handles.append(handle) + + return diagnostic_handles + + def _calculate_features( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Calculate node and edge features using the selected featurization strategy. + Returns lists of feature tensors from GNN layers. 
+ + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors + - List of edge feature tensors + In the case of feedforward featurization, each list contains a single tensor + from the final GNN layer. In the case of residual featurization, each list + contains tensors from all GNN layers. + """ + if self.featurizer_type == "feedforward": + return self._feedforward_featurization_impl(inputs, use_manual_attention) + else: + return self._residual_featurization_impl(inputs, use_manual_attention) + + def _feedforward_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Feedforward featurization: iterates features through all GNN layers, + returning only the final layer outputs. Uses combination MLPs to mix + forward and reversed edge messages at each layer. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from the final GNN layer + - List of edge feature tensors from the final GNN layer + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + + input_node_embeddings = self.node_embedders[0](inputs["element_indices_nodes"]) + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for combination_norm, combination_mlp, gnn_layer in zip( + self.combination_norms, self.combination_mlps, self.gnn_layers, strict=True + ): + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + input_node_embeddings = output_node_embeddings + new_input_edge_embeddings = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + # input_messages = 0.5 * (output_edge_embeddings + new_input_messages) + concatenated = torch.cat( + [output_edge_embeddings, new_input_edge_embeddings], dim=-1 + ) + input_edge_embeddings = ( + input_edge_embeddings + + output_edge_embeddings + + combination_mlp(combination_norm(concatenated)) + ) + + node_features_list.append(input_node_embeddings) + edge_features_list.append(input_edge_embeddings) + return node_features_list, edge_features_list + + def _residual_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Residual featurization: saves intermediate features from each GNN layer + for use in readout. 
Averages forward and reversed edge messages between layers. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from all GNN layers + - List of edge feature tensors from all GNN layers + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for node_embedder, gnn_layer in zip( + self.node_embedders, self.gnn_layers, strict=True + ): + input_node_embeddings = node_embedder(inputs["element_indices_nodes"]) + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + node_features_list.append(output_node_embeddings) + edge_features_list.append(output_edge_embeddings) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + # (Flatten, index, and reshape to the original shape) + new_input_messages = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + input_edge_embeddings = 0.5 * (input_edge_embeddings + new_input_messages) + return node_features_list, edge_features_list + + def _calculate_long_range_features( + self, + systems: List[System], + node_features_list: List[torch.Tensor], + edge_distances: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + """ + Calculate long-range electrostatic features using Ewald summation. + Forces use_ewald=True during training for stability. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_distances: Tensor of edge distances [n_atoms, max_num_neighbors]. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :return: Tensor of long-range features [n_atoms, d_pet]. + """ + if self.training: + # Currently, the long-range implementation show instabilities + # during training if P3MCalculator is used instead of the + # EwaldCalculator. We will use the EwaldCalculator for training. 
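
# [Editorial illustration, not part of the patch] Toy version of two recurring
# patterns in this file: the boolean padding-mask selection used just below
# (edge_distances[padding_mask]) and the cutoff-weighted neighbor sums used
# throughout the readout helpers further down.
import torch

edge_distances = torch.tensor([[1.0, 2.0, 0.0]])    # 1 atom, 3 neighbor slots
padding_mask = torch.tensor([[True, True, False]])  # last slot is padding
cutoff_factors = torch.tensor([[1.0, 0.25, 0.0]])   # smooth weights, 0 at the cutoff
edge_features = torch.ones(1, 3, 4)                 # [n_atoms, max_neighbors, d]

flattened_lengths = edge_distances[padding_mask]                       # tensor([1., 2.])
per_node = (edge_features * cutoff_factors[:, :, None]).sum(dim=1)     # [[1.25, 1.25, 1.25, 1.25]]
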
+ self.long_range_featurizer.use_ewald = True + flattened_lengths = edge_distances[padding_mask] + short_range_features = ( + torch.stack(node_features_list).sum(dim=0) + * (1 / len(node_features_list)) ** 0.5 + ) + long_range_features = self.long_range_featurizer( + systems, short_range_features, flattened_lengths + ) + return long_range_features + + def _get_output_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Concatenate node and edge features from all layers into intermediate + feature representations. Edge features are summed with cutoff weighting. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping "features" to a TensorMap of intermediate + representations, either per-atom or summed over atoms. + """ + features_dict: Dict[str, TensorMap] = {} + node_features = torch.cat(node_features_list, dim=1) + edge_features = torch.cat(edge_features_list, dim=2) + edge_features = (edge_features * cutoff_factors[:, :, None]).sum(dim=1) + features = torch.cat([node_features, edge_features], dim=1) + + feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=features, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + features.shape[-1], device=features.device + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + feature_tmap = mts.slice( + feature_tmap, + axis="samples", + selection=selected_atoms, + ) + if requested_outputs["features"].per_atom: + features_dict["features"] = feature_tmap + else: + features_dict["features"] = sum_over_atoms(feature_tmap) + return features_dict + + def _calculate_last_layer_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + ) -> Tuple[Dict[str, List[torch.Tensor]], Dict[str, List[torch.Tensor]]]: + """ + Apply output-specific heads to node and edge features from each GNN layer. + Returns dictionaries mapping output names to lists of head-transformed features. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of node last layer features + - Dictionary mapping output names to lists of edge last layer features + """ + node_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + + # Calculating node last layer features + for output_name, node_heads in self.node_heads.items(): + if output_name not in node_last_layer_features_dict: + node_last_layer_features_dict[output_name] = [] + for i, node_head in enumerate(node_heads): + node_last_layer_features_dict[output_name].append( + node_head(node_features_list[i]) + ) + + # Calculating edge last layer features + for output_name, edge_heads in self.edge_heads.items(): + if output_name not in edge_last_layer_features_dict: + edge_last_layer_features_dict[output_name] = [] + for i, edge_head in enumerate(edge_heads): + edge_last_layer_features_dict[output_name].append( + edge_head(edge_features_list[i]) + ) + + return node_last_layer_features_dict, edge_last_layer_features_dict + + def _get_output_last_layer_features( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge last layer features for requested last layer + features output. Edge features are summed with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping requested last layer features output names + to TensorMaps of last layer features, either per-atom or summed over atoms. 
+ """ + last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + last_layer_features_outputs: Dict[str, TensorMap] = {} + for output_name in node_last_layer_features_dict.keys(): + if not should_compute_last_layer_features(output_name, requested_outputs): + continue + if output_name not in last_layer_features_dict: + last_layer_features_dict[output_name] = [] + for i in range(len(node_last_layer_features_dict[output_name])): + node_last_layer_features = node_last_layer_features_dict[output_name][i] + edge_last_layer_features = edge_last_layer_features_dict[output_name][i] + edge_last_layer_features = ( + edge_last_layer_features * cutoff_factors[:, :, None] + ).sum(dim=1) + last_layer_features_dict[output_name].append(node_last_layer_features) + last_layer_features_dict[output_name].append(edge_last_layer_features) + + for output_name in requested_outputs: + if not ( + output_name.startswith("mtt::aux::") + and output_name.endswith("_last_layer_features") + ): + continue + base_name = output_name.replace("mtt::aux::", "").replace( + "_last_layer_features", "" + ) + # the corresponding output could be base_name or mtt::base_name + if f"mtt::{base_name}" in last_layer_features_dict: + base_name = f"mtt::{base_name}" + last_layer_features_values = torch.cat( + last_layer_features_dict[base_name], dim=1 + ) + last_layer_feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=last_layer_features_values, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + last_layer_features_values.shape[-1], + device=last_layer_features_values.device, + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + last_layer_feature_tmap = mts.slice( + last_layer_feature_tmap, + axis="samples", + selection=selected_atoms, + ) + last_layer_features_options = requested_outputs[output_name] + if last_layer_features_options.per_atom: + last_layer_features_outputs[output_name] = last_layer_feature_tmap + else: + last_layer_features_outputs[output_name] = sum_over_atoms( + last_layer_feature_tmap + ) + return last_layer_features_outputs + + def _calculate_atomic_predictions( + self, + systems: List[System], + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + padding_mask: torch.Tensor, + cutoff_factors: torch.Tensor, + outputs: Dict[str, ModelOutput], + ) -> Tuple[ + Dict[str, List[List[torch.Tensor]]], Dict[str, List[List[torch.Tensor]]] + ]: + """ + Apply final linear layers to last layer features to produce + per-atom predictions. Handles multiple blocks per output and sums + edge contributions with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param outputs: Dictionary of requested outputs. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of lists of node atomic + prediction tensors (one list per GNN layer, one tensor per block) + - Dictionary mapping output names to lists of lists of edge atomic + prediction tensors (one list per GNN layer, one tensor per block) + """ + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + + # Computing node atomic predictions. Since we have last layer features + # for each GNN layer, and each last layer can have multiple blocks, + # we apply each last layer block to each of the last layer features. + + batch_species_indices = self.species_to_species_index[ + torch.cat([system.types for system in systems], dim=0) + ] + + for output_name, node_last_layers in self.node_last_layers.items(): + if output_name in outputs: + node_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, node_last_layer in enumerate(node_last_layers): + node_last_layer_features = node_last_layer_features_dict[ + output_name + ][i] + node_atomic_predictions_by_block: List[torch.Tensor] = [] + for node_last_layer_by_block in node_last_layer.values(): + node_atomic_predictions_by_block.append( + node_last_layer_by_block( + batch_species_indices, + node_last_layer_features, + ) + ) + node_atomic_predictions_dict[output_name].append( + node_atomic_predictions_by_block + ) + + # Computing edge atomic predictions. Following the same logic as above, + # we (1) iterate over the last layer features and last layer blocks, and (2) + # sum the edge features with cutoff factors to get their per-node contribution. + + for output_name, edge_last_layers in self.edge_last_layers.items(): + if output_name in outputs: + edge_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, edge_last_layer in enumerate(edge_last_layers): + edge_last_layer_features = edge_last_layer_features_dict[ + output_name + ][i] + edge_atomic_predictions_by_block: List[torch.Tensor] = [] + for edge_last_layer_by_block in edge_last_layer.values(): + edge_atomic_predictions = edge_last_layer_by_block( + batch_species_indices, edge_last_layer_features + ) + expanded_padding_mask = padding_mask[..., None].repeat( + 1, 1, edge_atomic_predictions.shape[2] + ) + edge_atomic_predictions = torch.where( + ~expanded_padding_mask, 0.0, edge_atomic_predictions + ) + edge_atomic_predictions_by_block.append( + (edge_atomic_predictions * cutoff_factors[:, :, None]).sum( + dim=1 + ) + ) + edge_atomic_predictions_dict[output_name].append( + edge_atomic_predictions_by_block + ) + + return node_atomic_predictions_dict, edge_atomic_predictions_dict + + def _get_output_atomic_predictions( + self, + systems: List[System], + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_vectors: torch.Tensor, + system_indices: torch.Tensor, + sample_labels: Labels, + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge atomic predictions into final TensorMaps. + Handles rank-2 Cartesian tensors by symmetrizing them. + Returns per-atom or per-structure predictions based on output configuration. + + :param systems: List of `metatomic.torch.System` objects to process. 
+ :param node_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of node atomic prediction tensors (one list per GNN layer, + one tensor per block). + :param edge_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of edge atomic prediction tensors (one list per GNN layer, + one tensor per block). + :param edge_vectors: Tensor of edge vectors [n_atoms, max_num_neighbors, 3]. + :param system_indices: Tensor mapping each atom to its system index + [n_atoms]. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param outputs: Dictionary of requested outputs. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :return: Dictionary mapping requested output names to TensorMaps of + predictions, either per-atom or summed over atoms. + """ + atomic_predictions_tmap_dict: Dict[str, TensorMap] = {} + for output_name in self.target_names: + if output_name in outputs: + atomic_predictions_by_block = { + key: torch.zeros( + 1, dtype=edge_vectors.dtype, device=edge_vectors.device + ) + for key in self.output_shapes[output_name].keys() + } + + node_atomic_predictions_by_block = node_atomic_predictions_dict[ + output_name + ] + edge_atomic_predictions_by_block = edge_atomic_predictions_dict[ + output_name + ] + for i in range(len(node_atomic_predictions_by_block)): + node_atomic_prediction_block = node_atomic_predictions_by_block[i] + edge_atomic_prediction_block = edge_atomic_predictions_by_block[i] + for j, key in enumerate(atomic_predictions_by_block): + node_atomic_predictions = node_atomic_prediction_block[j] + edge_atomic_predictions = edge_atomic_prediction_block[j] + atomic_predictions_by_block[key] = atomic_predictions_by_block[ + key + ] + (node_atomic_predictions + edge_atomic_predictions) + + if output_name == "non_conservative_stress": # TODO: variants + block_key = list(atomic_predictions_by_block.keys())[0] + output_shapes_values = list( + self.output_shapes[output_name].values() + ) + num_properties = output_shapes_values[0][-1] + symmetrized = process_non_conservative_stress( + atomic_predictions_by_block[block_key], + systems, + system_indices, + num_properties, + ) + atomic_predictions_by_block[block_key] = symmetrized + + blocks = [ + TensorBlock( + values=atomic_predictions_by_block[key].reshape([-1] + shape), + samples=sample_labels, + components=components, + properties=properties, + ) + for key, shape, components, properties in zip( + self.output_shapes[output_name].keys(), + self.output_shapes[output_name].values(), + self.component_labels[output_name], + self.property_labels[output_name], + strict=True, + ) + ] + atomic_predictions_tmap_dict[output_name] = TensorMap( + keys=self.key_labels[output_name], + blocks=blocks, + ) + # If selected atoms request is provided, we slice the atomic predictions + # tensor maps to get the predictions for the selected atoms only. + + if selected_atoms is not None: + for output_name, tmap in atomic_predictions_tmap_dict.items(): + atomic_predictions_tmap_dict[output_name] = mts.slice( + tmap, axis="samples", selection=selected_atoms + ) + + # If per-atom predictions are requested, we return the atomic predictions + # tensor maps. Otherwise, we sum the atomic predictions over the atoms + # to get the final per-structure predictions for each requested output. 
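
# [Editorial illustration, not part of the patch] The essence of the
# symmetrization performed by `process_non_conservative_stress` (called a few
# lines above for the "non_conservative_stress" output and defined near the
# bottom of this file), shown on a toy tensor with the volume normalization
# omitted:
import torch

raw = torch.randn(5, 9).reshape(-1, 3, 3, 1)       # [n_atoms, 3, 3, n_properties]
sym = (raw + raw.transpose(1, 2)) / 2.0
assert torch.allclose(sym, sym.transpose(1, 2))     # symmetric 3x3 block per atom
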
+ + for output_name, atomic_property in atomic_predictions_tmap_dict.items(): + if outputs[output_name].per_atom: + atomic_predictions_tmap_dict[output_name] = atomic_property + else: + atomic_predictions_tmap_dict[output_name] = sum_over_atoms( + atomic_property + ) + + return atomic_predictions_tmap_dict + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + context: Literal["restart", "finetune", "export"], + ) -> "PET": + if context == "restart": + logging.info(f"Using latest model from epoch {checkpoint['epoch']}") + model_state_dict = checkpoint["model_state_dict"] + elif context in {"finetune", "export"}: + logging.info(f"Using best model from epoch {checkpoint['best_epoch']}") + model_state_dict = checkpoint["best_model_state_dict"] + else: + raise ValueError("Unknown context tag for checkpoint loading!") + + # Create the model + model_data = checkpoint["model_data"] + model = cls( + hypers=model_data["model_hypers"], + dataset_info=model_data["dataset_info"], + ) + + finetune_config = model_state_dict.pop("finetune_config", {}) + if finetune_config: + # Apply the finetuning strategy + model = apply_finetuning_strategy(model, finetune_config) + state_dict_iter = iter(model_state_dict.values()) + next(state_dict_iter) # skip the species_to_species_index + dtype = next(state_dict_iter).dtype + model.to(dtype).load_state_dict(model_state_dict) + model.additive_models[0].sync_tensor_maps() + model.scaler.sync_tensor_maps() + + # Loading the metadata from the checkpoint + model.metadata = merge_metadata(model.metadata, checkpoint.get("metadata")) + + return model + + def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: + dtype = next(self.parameters()).dtype + if dtype not in self.__supported_dtypes__: + raise ValueError(f"unsupported dtype {dtype} for PET") + + # Make sure the model is all in the same dtype + # For example, after training, the additive models could still be in + # float64 + self.to(dtype) + + # Additionally, the composition model contains some `TensorMap`s that cannot + # be registered correctly with Pytorch. This function moves them: + self.additive_models[0].weights_to(torch.device("cpu"), torch.float64) + + interaction_ranges = [self.num_gnn_layers * self.cutoff] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + + capabilities = ModelCapabilities( + outputs=self.outputs, + atomic_types=self.atomic_types, + interaction_range=interaction_range, + length_unit=self.dataset_info.length_unit, + supported_devices=self.__supported_devices__, + dtype=dtype_to_str(dtype), + ) + + metadata = merge_metadata(self.metadata, metadata) + + return AtomisticModel(self.eval(), metadata, capabilities) + + def _add_output(self, target_name: str, target_info: TargetInfo) -> None: + """ + Register a new output target by creating corresponding heads and last layers. + Sets up node/edge heads and linear layers for all readout layers. + + :param target_name: Name of the target to add. + :param target_info: TargetInfo object containing details about the target. + """ + # one output shape for each tensor block, grouped by target (i.e. 
tensormap) + self.output_shapes[target_name] = {} + for key, block in target_info.layout.items(): + dict_key = target_name + for n, k in zip(key.names, key.values, strict=True): + dict_key += f"_{n}_{int(k)}" + self.output_shapes[target_name][dict_key] = [ + len(comp.values) for comp in block.components + ] + [len(block.properties.values)] + + self.outputs[target_name] = ModelOutput( + quantity=target_info.quantity, + unit=target_info.unit, + per_atom=True, + description=target_info.description, + ) + + self.node_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_node, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_pet, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.node_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: ReadoutLayer( + self.d_head, + prod(shape), + len(self.atomic_types), + bias=True, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: ReadoutLayer( + self.d_head, + prod(shape), + len(self.atomic_types), + bias=True, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + # Register last-layer parameters, in the same order as they are returned as + # last-layer features in the model + self.last_layer_parameter_names[target_name] = [] + for layer_index in range(self.num_readout_layers): + for key in self.output_shapes[target_name].keys(): + self.last_layer_parameter_names[target_name].append( + f"node_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + self.last_layer_parameter_names[target_name].append( + f"edge_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + + ll_features_name = get_last_layer_features_name(target_name) + self.outputs[ll_features_name] = ModelOutput( + per_atom=True, description=f"last layer features for {target_name}" + ) + self.key_labels[target_name] = target_info.layout.keys + self.component_labels[target_name] = [ + block.components for block in target_info.layout.blocks() + ] + self.property_labels[target_name] = [ + block.properties for block in target_info.layout.blocks() + ] + + def _move_labels_to_device(self, device: torch.device) -> None: + self.single_label = self.single_label.to(device) + self.key_labels = { + output_name: label.to(device) + for output_name, label in self.key_labels.items() + } + self.component_labels = { + output_name: [ + [labels.to(device) for labels in components_block] + for components_block in components_tmap + ] + for output_name, components_tmap in self.component_labels.items() + } + self.property_labels = { + output_name: [labels.to(device) for labels in properties_tmap] + for output_name, properties_tmap in self.property_labels.items() + } + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["model_ckpt_version"] == v: + update = getattr(checkpoints, f"model_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["model_ckpt_version"] = v + 1 + + if 
checkpoint["model_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using model " + f"version {checkpoint['model_ckpt_version']}, while the current model " + f"version is {cls.__checkpoint_version__}." + ) + + return checkpoint + + def get_checkpoint(self) -> Dict: + model_state_dict = self.state_dict() + model_state_dict["finetune_config"] = self.finetune_config + checkpoint = { + "architecture_name": "pet_4", + "model_ckpt_version": self.__checkpoint_version__, + "metadata": self.metadata, + "model_data": { + "model_hypers": self.hypers, + "dataset_info": self.dataset_info, + }, + "epoch": None, + "best_epoch": None, + "model_state_dict": model_state_dict, + "best_model_state_dict": self.state_dict(), + } + return checkpoint + + +def process_non_conservative_stress( + tensor: torch.Tensor, + systems: List[System], + system_indices: torch.Tensor, + num_properties: int, +) -> torch.Tensor: + """ + Symmetrizes and normalizes by the volume rank-2 Cartesian tensors that are meant + to predict the non-conservative stress. + + :param tensor: Tensor of shape [n_atoms, 9 * num_properties]. + :param systems: List of `metatomic.torch.System` objects to process. + :param system_indices: Tensor mapping each atom to its system index [n_atoms]. + :param num_properties: Number of properties in the tensor (e.g., 6 for stress). + :return: Symmetrized tensor of shape [n_atoms, 3, 3, num_properties], divided by the + cell volume. + """ + # Reshape to 3x3 matrix per atom + tensor_as_three_by_three = tensor.reshape(-1, 3, 3, num_properties) + + # Normalize by cell volume + volumes = torch.stack([torch.abs(torch.det(system.cell)) for system in systems]) + # Zero volume can happen due to metatomic's convention of zero cell + # vectors for non-periodic directions. The actual volume is +inf + volumes[volumes == 0.0] = torch.inf + volumes_by_atom = volumes[system_indices].unsqueeze(1).unsqueeze(2).unsqueeze(3) + tensor_as_three_by_three = tensor_as_three_by_three / volumes_by_atom + + # Symmetrize + tensor_as_three_by_three = ( + tensor_as_three_by_three + tensor_as_three_by_three.transpose(1, 2) + ) / 2.0 + + return tensor_as_three_by_three + + +def get_last_layer_features_name(target_name: str) -> str: + """ + Get the auxiliary output name for last layer features of a target. + + :param target_name: Name of the target. + :return: Name of the corresponding last layer features output. + """ + base_name = target_name.replace("mtt::", "") + return f"mtt::aux::{base_name}_last_layer_features" + + +def should_compute_last_layer_features( + output_name: str, requested_outputs: Dict[str, ModelOutput] +) -> bool: + """ + Check if last layer features should be computed for an output. + + :param output_name: Name of the output to check. + :param requested_outputs: Dictionary of requested outputs. + :return: True if last layer features should be computed, False otherwise. + """ + if output_name in requested_outputs: + return True + ll_features_name = get_last_layer_features_name( + output_name.replace("mtt::aux::", "") + ) + return ll_features_name in requested_outputs diff --git a/src/metatrain/pet_4/modules/adaptive_cutoff.py b/src/metatrain/pet_4/modules/adaptive_cutoff.py new file mode 100644 index 0000000000..5f6206bdb1 --- /dev/null +++ b/src/metatrain/pet_4/modules/adaptive_cutoff.py @@ -0,0 +1,166 @@ +from typing import Optional + +import torch + +from .utilities import cutoff_func_bump as cutoff_func + + +# minimum value for the probe cutoff. 
this avoids getting too close +# to the central atom. in practice it could be also set to a larger value +DEFAULT_MIN_PROBE_CUTOFF = 0.5 +# recommended smooth cutoff width for effective neighbor number calculation +# smaller values lead to a more "step-like" behavior, but can be +# numerically unstable. in practice this will be called with the +# same cutoff as the main cutoff function +DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH = 1.0 + + +def get_adaptive_cutoffs( + centers: torch.Tensor, + edge_distances: torch.Tensor, + num_neighbors_adaptive: float, + num_nodes: int, + max_cutoff: float, + min_cutoff: float = DEFAULT_MIN_PROBE_CUTOFF, + cutoff_width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, + probe_spacing: Optional[float] = None, + weight_width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the adaptive cutoff values for each center atom. + + :param centers: Indices of the center atoms. + :param edge_distances: Distances between centers and their neighbors. + :param num_neighbors_adaptive: Target number of neighbors per atom. + :param num_nodes: Total number of center atoms. + :param max_cutoff: Maximum cutoff distance to consider. + :param min_cutoff: Minimum cutoff distance to consider. + :param cutoff_width: Width of the smooth cutoff taper region. + :param probe_spacing: Spacing between probe cutoffs. If None, it will be + automatically determined from the cutoff width. + :param weight_width: Width of the cutoff selection weight function. If None, it + will be automatically determined from the empirical neighbor counts. + :return: Adapted cutoff distances for each center atom. + """ + + # heuristic for the grid spacing of probe cutoffs, based on a + # the smoothness of the cutoff function + if probe_spacing is None: + probe_spacing = cutoff_width / 4.0 + probe_cutoffs = torch.arange( + min_cutoff, + max_cutoff, + probe_spacing, + device=edge_distances.device, + dtype=edge_distances.dtype, + ) + with torch.profiler.record_function("PET::get_effective_num_neighbors"): + effective_num_neighbors = get_effective_num_neighbors( + edge_distances, + probe_cutoffs, + centers, + num_nodes, + width=cutoff_width, + ) + + with torch.profiler.record_function("PET::get_cutoff_weights"): + cutoffs_weights = get_gaussian_cutoff_weights( + effective_num_neighbors, + num_neighbors_adaptive, + width=weight_width, + ) + with torch.profiler.record_function("PET::calculate_adapted_cutoffs"): + adapted_atomic_cutoffs = probe_cutoffs @ cutoffs_weights.T + return adapted_atomic_cutoffs + + +def get_effective_num_neighbors( + edge_distances: torch.Tensor, + probe_cutoffs: torch.Tensor, + centers: torch.Tensor, + num_nodes: int, + width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, +) -> torch.Tensor: + """ + Computes the effective number of neighbors for each probe cutoff. + + :param edge_distances: Distances between centers and their neighbors. + :param probe_cutoffs: Probe cutoff distances. + :param centers: Indices of the center atoms. + :param num_nodes: Total number of center atoms. + :param width: Width of the cutoff function. + :return: Effective number of neighbors for each center atom and probe cutoff. 
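+
+    Illustrative usage (toy values; the shapes are the important part)::
+
+        distances = torch.tensor([1.0, 2.5, 3.0])  # three edges
+        probes = torch.tensor([2.0, 4.0])  # two probe cutoffs
+        centers = torch.tensor([0, 0, 1])  # edges belong to atoms 0, 0 and 1
+        n_eff = get_effective_num_neighbors(distances, probes, centers, num_nodes=2)
+        # n_eff has shape (2, 2): a smooth neighbor count per atom and probe cutoff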
+ """ + + weights = cutoff_func( + edge_distances.unsqueeze(0), probe_cutoffs.unsqueeze(1), width + ) + + probe_num_neighbors = torch.zeros( + (len(probe_cutoffs), num_nodes), + dtype=edge_distances.dtype, + device=edge_distances.device, + ) + # accumulate the weights for all probe cutoffs and center atoms at once + probe_num_neighbors.index_add_(1, centers, weights) + probe_num_neighbors = probe_num_neighbors.T + + return probe_num_neighbors + + +def get_gaussian_cutoff_weights( + effective_num_neighbors: torch.Tensor, + num_neighbors_adaptive: float, + width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the weights for each probe cutoff based on + the effective number of neighbors using Gaussian weights + centered at the expected number of neighbors. + + :param effective_num_neighbors: Effective number of neighbors for each center atom + and probe cutoff. + :param num_neighbors_adaptive: Target maximum number of neighbors per atom. + :param width: Width of the Gaussian cutoff selection function. + :return: Weights for each probe cutoff. + """ + diff = effective_num_neighbors - num_neighbors_adaptive + + # adds a "baseline" corresponding to uniformly-distributed atoms + # this has multiple "good" effects: it pushes the cutoff "out" when + # there are few neighbors, and "in" when there are many, and it + # stabilizes the weights with respect to variations in the neighbor + # distribution when there are empty ranges leading to "flat" + # neighbor count distribution + x = torch.linspace( + 0, + 1, + effective_num_neighbors.shape[1], + device=effective_num_neighbors.device, + dtype=effective_num_neighbors.dtype, + ) + baseline = num_neighbors_adaptive * x**3 + + diff = diff + baseline.unsqueeze(0) + if width is None: + # adaptive width from neighbor-count slope along probe axis (last dim) + eps = 1e-12 + if diff.shape[-1] == 1: + # Can't compute gradient from single point; use scaled diff as proxy + width_t = diff.abs() * 0.5 + eps + else: + # Compute numerical gradient: centered differences for interior, + # one-sided differences at boundaries + (width_t,) = torch.gradient(diff, dim=-1) + width_t = width_t.abs().clamp_min(eps) + else: + width_t = torch.ones_like(diff) * width + + logw = -0.5 * (diff / width_t) ** 2 + weights = torch.exp(logw - logw.max()) + + # row-wise normalization of the weights + weights_sum = weights.sum(dim=1, keepdim=True) + weights = weights / weights_sum + + return weights diff --git a/src/metatrain/pet_4/modules/finetuning.py b/src/metatrain/pet_4/modules/finetuning.py new file mode 100644 index 0000000000..27f2086f19 --- /dev/null +++ b/src/metatrain/pet_4/modules/finetuning.py @@ -0,0 +1,264 @@ +# mypy: disable-error-code=misc +# We ignore misc errors in this file because TypedDict +# with default values is not allowed by mypy. 
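+
+# A minimal illustrative example (hypothetical checkpoint path and values) of a
+# strategy dictionary matching the TypedDicts below, as it could be passed to
+# ``apply_finetuning_strategy``:
+#
+#     strategy: LoRaFinetuneHypers = {
+#         "method": "lora",
+#         "read_from": "pretrained_model.ckpt",
+#         "config": {"rank": 4, "alpha": 8.0},
+#         "inherit_heads": {},
+#     }
+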
+from typing import Any, Optional, Tuple + +import torch +import torch.nn as nn +from typing_extensions import Literal, NotRequired, TypedDict + + +class LoRaFinetuneConfig(TypedDict): + """Configuration for LoRA finetuning strategy.""" + + rank: int + """Rank of the LoRA matrices.""" + alpha: float + """Scaling factor for the LoRA matrices.""" + target_modules: NotRequired[list[str]] + + +class HeadsFinetuneConfig(TypedDict): + """Configuration for heads finetuning strategy.""" + + head_modules: list[str] + """List of module name prefixes for the prediction heads to finetune.""" + last_layer_modules: list[str] + """List of module name prefixes for the last layers to finetune.""" + + +class NoFinetuneHypers(TypedDict): + """Hypers that indicate that no finetuning is to be applied.""" + + read_from: None = None + """No finetuning is indicated by setting this argument to None. + + The rest of finetuning hyperparameters are then ignored. + """ + method: NotRequired[Any] + config: NotRequired[Any] + inherit_heads: NotRequired[Any] + + +class FullFinetuneHypers(TypedDict): + """Hyperparameters to use full finetuning of PET models. + + This means all model parameters are trainable. + """ + + method: Literal["full"] = "full" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: NotRequired[Any] + """No configuration needed for full finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class LoRaFinetuneHypers(TypedDict): + """Hyperparameters for LoRA finetuning of PET models. + + Injects LoRA layers and finetunes only them. + """ + + method: Literal["lora"] = "lora" + """Finetuning method to use""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: LoRaFinetuneConfig + """Configuration for LoRA finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class HeadsFinetuneHypers(TypedDict): + """Hyperparameters for heads finetuning of PET models. + + Freezes all model parameters except for the prediction heads + and last layers. + """ + + method: Literal["heads"] = "heads" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: HeadsFinetuneConfig + """Configuration for heads finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +FinetuneHypers = FullFinetuneHypers | LoRaFinetuneHypers | HeadsFinetuneHypers + + +def apply_finetuning_strategy(model: nn.Module, strategy: FinetuneHypers) -> nn.Module: + """ + Apply the specified finetuning strategy to the model. + This function modifies the model in place based on the provided strategy. + + :param model: The model to be finetuned. + :param strategy: A dictionary specifying the finetuning strategy. + The strategy method can be one of the following: + - lora: Inject LoRA layers into the model, or reapply training if already + present. 
+ - heads: Freeze all parameters except for the heads and last layers. + - full: All parameters are trainable. + Additionally, the strategy can include an "inherit_heads" key, + which is a dictionary mapping the new trainable targets to the existing + targets in the model. This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization. + :return: The modified model with the finetuning strategy applied. + """ + + for param in model.parameters(): + param.requires_grad = True + + if strategy["method"] == "full": + # Full finetuning, all parameters are trainable + pass + + elif strategy["method"] == "lora": + lora_config = strategy["config"] + lora_already_applied = any(isinstance(m, LoRALinear) for m in model.modules()) + if not lora_already_applied: + model_device = next(model.parameters()).device + model_dtype = next(model.parameters()).dtype + model = inject_lora_layers( + model, + target_modules=tuple( + lora_config.get("target_modules", ["input_linear", "output_linear"]) + ), + rank=lora_config.get("rank", 4), + alpha=lora_config.get("alpha", 8), + device=model_device, + dtype=model_dtype, + ) + + # Freeze all except LoRA + for name, param in model.named_parameters(): + if "lora_" not in name: + param.requires_grad = False + + elif strategy["method"] == "heads": + heads_config = strategy.get( + "config", + { + "head_modules": ["node_heads", "edge_heads"], + "last_layer_modules": ["node_last_layers", "edge_last_layers"], + }, + ) + + head_keywords = heads_config.get("head_modules", []) + last_layer_keywords = heads_config.get("last_layer_modules", []) + + for name, param in model.named_parameters(): + if any(name.startswith(kw) for kw in head_keywords + last_layer_keywords): + param.requires_grad = True + else: + param.requires_grad = False + + else: + raise ValueError( + f"Unknown finetuning strategy: {strategy['method']}. Available methods " + "are: 'full', 'lora', 'heads'." + ) + + model.finetune_config = strategy + + inherit_heads_config = strategy["inherit_heads"] + if inherit_heads_config: + for dest_target_name, source_target_name in inherit_heads_config.items(): + model_parameters = dict(model.named_parameters()) + if not any(f".{source_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the source target name '{source_target_name}' was not found in " + "the model. Please specify the correct source target name." + ) + if not any(f".{dest_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the destination target name '{dest_target_name}' was not found " + "in the model. Please specify the correct destination target name." + ) + for name, param in model_parameters.items(): + if f".{source_target_name}." in name: + corresponding_dest_name = name.replace( + source_target_name, dest_target_name + ) + if corresponding_dest_name in model_parameters: + model_parameters[corresponding_dest_name].data.copy_(param.data) + else: + raise ValueError( + f"Destination head '{dest_target_name}' not found in model." + ) + return model + + +def inject_lora_layers( + model: nn.Module, + target_modules: Tuple[str, ...] = ("input_linear", "output_linear"), + rank: int = 4, + alpha: float = 1.0, + device: Optional[str] = None, + dtype: Optional[torch.dtype] = None, +) -> nn.Module: + """ + Inject LoRA layers into the model. 
+ This function replaces the specified linear layers in the model with + LoRALinear layers. + + :param model: The model to modify. + :param target_modules: A tuple of strings specifying the names of the attributes of + the modules to be replaced with LoRA layers. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + :param device: The device to which the LoRA layers should be moved. If None, the + LoRA layers will be on the same device as the original model. + :param dtype: The data type to which the LoRA layers should be converted. If + None, the LoRA layers will have the same dtype as the original model. + :return: The modified model with LoRA layers injected. + """ + for _, module in model.named_modules(): + for attr in target_modules: + if hasattr(module, attr): + linear = getattr(module, attr) + if isinstance(linear, nn.Linear): + lora_linear = LoRALinear(linear, rank=rank, alpha=alpha) + lora_linear = lora_linear.to(dtype=dtype, device=device) + setattr(module, attr, lora_linear) + return model + + +class LoRALinear(nn.Module): + """ + LoRA Linear layer. + This is a wrapper around nn.Linear that adds LoRA functionality. + LoRA is a technique for low-rank adaptation of large language models. + It allows for efficient fine-tuning of large models by injecting low-rank + matrices into the model's weights. + + :param linear_layer: The original linear layer to be wrapped. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + """ + + def __init__(self, linear_layer: nn.Module, rank: int = 4, alpha: float = 1.0): + super().__init__() + self.linear = linear_layer + self.lora_A = nn.Linear(linear_layer.in_features, rank, bias=False) + self.lora_B = nn.Linear(rank, linear_layer.out_features, bias=False) + self.scaling = alpha / rank + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + self.scaling * self.lora_B(self.lora_A(x)) diff --git a/src/metatrain/pet_4/modules/nef.py b/src/metatrain/pet_4/modules/nef.py new file mode 100644 index 0000000000..20ed6dc8fc --- /dev/null +++ b/src/metatrain/pet_4/modules/nef.py @@ -0,0 +1,248 @@ +""" +Module with functions to manipulate NEF (Node Edge Feature) arrays. + +The NEF representation is what the internals of PET use. +In the NEF representation, the first dimension is the center node +(i.e. the "i" node in an "i -> j" edge), and the second dimension +is the edges for that node. Not all center nodes have the same number +of edges, so padding is used to ensure that all nodes have the same +number of edges. + +Most of the functions have the purpose of converting between +edge arrays with shape (n_edges, ...) and NEF arrays with shape +(n_nodes, n_edges_per_node, ...). +""" + +from typing import List, Optional, Tuple + +import torch + + +def get_nef_indices( + centers: torch.Tensor, n_nodes: int, n_edges_per_node: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes tensors of indices useful to convert between edge + and NEF layouts; the usage and function of `nef_indices` and + `nef_to_edges_neighbor` is clear in the ``edge_array_to_nef`` + and ``nef_array_to_edges`` functions below. + + :param centers: A 1D tensor of shape (n_edges,) containing the + indices of the center nodes for each edge, with the center nodes + being the "i" node in an "i -> j" edge. + :param n_nodes: The number of nodes in the graph. + :param n_edges_per_node: The maximum number of edges per node. 
+ + :return: A tuple with three tensors (nef_indices, nef_to_edges_neighbor, nef_mask). + In particular: + nef_array = edge_array[nef_indices] + edge_array = nef_array[centers, nef_to_edges_neighbor] + The third output, nef_mask, is a mask that can be used to + filter out the padding values in the NEF array, as different + nodes will have, in general, different number of edges. + """ + + bincount = torch.bincount(centers, minlength=n_nodes) + + arange = torch.arange(n_edges_per_node, device=centers.device) + arange_expanded = arange.view(1, -1).expand(n_nodes, -1) + nef_mask = arange_expanded < bincount.view(-1, 1) + + argsort = torch.argsort(centers, stable=True) + + nef_indices = torch.zeros( + (n_nodes, n_edges_per_node), dtype=torch.long, device=centers.device + ) + nef_indices[nef_mask] = argsort + + nef_to_edges_neighbor = torch.empty_like(centers, dtype=torch.long) + nef_to_edges_neighbor[argsort] = arange_expanded[nef_mask] + + return nef_indices, nef_to_edges_neighbor, nef_mask + + +def get_corresponding_edges(array: torch.Tensor) -> torch.Tensor: + """ + Computes the corresponding edge (i.e., the edge that goes in the + opposite direction) for each edge in the array; this is useful + in the message-passing operation. + + :param array: A 2D tensor of shape (n_edges, 5). For each i -> j + edge, the first column contains the index of the center node i, + the second column contains the index of the neighbor node j, + and the last three columns contain the cell shifts along x, y, and z + directions, respectively. + + :return: A 1D tensor of shape (n_edges,) containing, for each edge, + the index of the corresponding edge (i.e., the edge that goes + in the opposite direction). If the input array is empty, an + empty tensor is returned. + """ + + if array.numel() == 0: + return torch.empty((0,), dtype=array.dtype, device=array.device) + + array = array.to(torch.int64) # avoid overflow + + centers = array[:, 0] + neighbors = array[:, 1] + cell_shifts_x = array[:, 2] + cell_shifts_y = array[:, 3] + cell_shifts_z = array[:, 4] + + # will be useful later + negative_cell_shifts_x = -cell_shifts_x + negative_cell_shifts_y = -cell_shifts_y + negative_cell_shifts_z = -cell_shifts_z + + # create a unique identifier for each edge + # first, we shift the cell_shifts so that the minimum value is 0 + min_cell_shift_x = cell_shifts_x.min() + cell_shifts_x = cell_shifts_x - min_cell_shift_x + negative_cell_shifts_x = negative_cell_shifts_x - min_cell_shift_x + + min_cell_shift_y = cell_shifts_y.min() + cell_shifts_y = cell_shifts_y - min_cell_shift_y + negative_cell_shifts_y = negative_cell_shifts_y - min_cell_shift_y + + min_cell_shift_z = cell_shifts_z.min() + cell_shifts_z = cell_shifts_z - min_cell_shift_z + negative_cell_shifts_z = negative_cell_shifts_z - min_cell_shift_z + + max_centers_neigbors = centers.max() + 1 # same as neighbors.max() + 1 + max_shift_x = cell_shifts_x.max() + 1 + max_shift_y = cell_shifts_y.max() + 1 + max_shift_z = cell_shifts_z.max() + 1 + + size_1 = max_shift_z + size_2 = max_shift_y * size_1 + size_3 = max_shift_x * size_2 + size_4 = max_centers_neigbors * size_3 + + unique_id = ( + centers * size_4 + + neighbors * size_3 + + cell_shifts_x * size_2 + + cell_shifts_y * size_1 + + cell_shifts_z + ) + + # the inverse is the same, but centers and neighbors are swapped + # and we use the negative values of the cell_shifts + unique_id_inverse = ( + neighbors * size_4 + + centers * size_3 + + negative_cell_shifts_x * size_2 + + negative_cell_shifts_y * size_1 + + 
negative_cell_shifts_z + ) + + unique_id_argsort = unique_id.argsort() + unique_id_inverse_argsort = unique_id_inverse.argsort() + + corresponding_edges = torch.empty_like(centers) + corresponding_edges[unique_id_argsort] = unique_id_inverse_argsort + + return corresponding_edges.to(array.dtype) + + +def edge_array_to_nef( + edge_array: torch.Tensor, + nef_indices: torch.Tensor, + mask: Optional[torch.Tensor] = None, + fill_value: float = 0.0, +) -> torch.Tensor: + """Converts an edge array to a NEF array. + + :param edge_array: A tensor where the first dimension is the index of + the edge, i.e. with shape (n_edges, ...). + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param mask: An optional boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. If provided, + the output NEF array will have the values in the positions + where the mask is False set to ``fill_value``. + :param fill_value: The value to use to fill the positions in the + NEF array where the mask is False. Only used if ``mask`` is + provided. + + :return: A tensor with the same information as ``edge_array``, + but in NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + If ``mask`` is provided, the values in the positions where + the mask is False are set to ``fill_value``. + """ + if mask is None: + return edge_array[nef_indices] + else: + return torch.where( + mask.reshape(mask.shape + (1,) * (len(edge_array.shape) - 1)), + edge_array[nef_indices], + fill_value, + ) + + +def nef_array_to_edges( + nef_array: torch.Tensor, centers: torch.Tensor, nef_to_edges_neighbor: torch.Tensor +) -> torch.Tensor: + """Converts a NEF array to an edge array. + + :param nef_array: A tensor where the first two dimensions are the + indices of the NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + :param centers: The indices of the center nodes for each edge. + :param nef_to_edges_neighbor: The indices of the edges for each + neighbor in the NEF layout, as returned by the ``get_nef_indices`` function. + + :return: A tensor with the same information as ``nef_array``, + but in edge layout, i.e. with shape (n_edges, ...). + """ + return nef_array[centers, nef_to_edges_neighbor] + + +def compute_reversed_neighbor_list( + nef_indices: torch.Tensor, + corresponding_edges: torch.Tensor, + nef_mask: torch.Tensor, +) -> torch.Tensor: + """ + Creates a reversed neighborlist, where for each + center atom `i` and its neighbor `j` in the original + neighborlist, the position of atom `i` in the list + of neighbors of atom `j` is returned. + + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param corresponding_edges: The indices of the corresponding edges, + as returned by the ``get_corresponding_edges`` function. + :param nef_mask: A boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. + :return: A tensor of the same shape as ``nef_indices``, + where each entry contains the position of the center + atom in the neighborlist of the corresponding neighbor atom. 
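+
+    As an illustration of the convention (with ``neighbors_index`` denoting the
+    neighbor atom indices in NEF layout): if ``nef_indices[i, k]`` stores the edge
+    ``i -> j``, the reversed edge ``j -> i`` is found in the row of atom ``j`` at
+    column ``reversed_neighbor_list[i, k]``, i.e. at
+    ``nef_array[neighbors_index[i, k], reversed_neighbor_list[i, k]]`` for any edge
+    array in NEF layout.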
+ """ + num_atoms, max_num_neighbors = nef_indices.shape + + flat_edge_indices = nef_indices.reshape(-1) + flat_positions = torch.arange(max_num_neighbors, device=nef_indices.device).repeat( + num_atoms + ) + flat_mask = nef_mask.reshape(-1) + + if flat_edge_indices.numel() == 0: + max_edge_index = 0 + else: + max_edge_index = int(flat_edge_indices.max().item()) + 1 + size: List[int] = [max_edge_index] + + edge_index_to_position = torch.full( + size, + 0, + dtype=torch.long, + device=nef_indices.device, + ) + edge_index_to_position[flat_edge_indices[flat_mask]] = flat_positions[flat_mask] + + reverse_edge_idx = corresponding_edges[nef_indices] + reversed_neighbor_list = edge_index_to_position[reverse_edge_idx] + reversed_neighbor_list = reversed_neighbor_list.masked_fill(~nef_mask, 0) + + return reversed_neighbor_list diff --git a/src/metatrain/pet_4/modules/readout.py b/src/metatrain/pet_4/modules/readout.py new file mode 100644 index 0000000000..e5eaaae74b --- /dev/null +++ b/src/metatrain/pet_4/modules/readout.py @@ -0,0 +1,32 @@ +from typing import List +import torch + +from metatomic.torch import System + + +class ReadoutLayer(torch.nn.Module): + def __init__( + self, + feature_dim: int, + output_dim: int, + num_atomic_types: int, + bias: bool = True, + ): + super().__init__() + self.shift = torch.nn.Embedding( + num_atomic_types, feature_dim + ) + self.linear = torch.nn.Linear(feature_dim, output_dim, bias=bias) + + def forward( + self, + batch_species_indices: torch.Tensor, + features: torch.Tensor, + ) -> torch.Tensor: + + shift = self.shift(batch_species_indices) + if features.dim() == 3: # edge features + shift = shift.unsqueeze(1) + features = features + shift + + return self.linear(features) \ No newline at end of file diff --git a/src/metatrain/pet_4/modules/structures.py b/src/metatrain/pet_4/modules/structures.py new file mode 100644 index 0000000000..f7e67f9eaf --- /dev/null +++ b/src/metatrain/pet_4/modules/structures.py @@ -0,0 +1,377 @@ +from typing import List, Optional, Tuple + +import torch +from metatensor.torch import Labels +from metatomic.torch import NeighborListOptions, System + +from .adaptive_cutoff import get_adaptive_cutoffs +from .nef import ( + compute_reversed_neighbor_list, + edge_array_to_nef, + get_corresponding_edges, + get_nef_indices, +) +from .utilities import cutoff_func_bump, cutoff_func_cosine + + +def concatenate_structures( + systems: List[System], + neighbor_list_options: NeighborListOptions, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, +]: + """ + Concatenate a list of systems into a single batch. + + :param systems: List of systems to concatenate. + :param neighbor_list_options: Options for the neighbor list. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the concatenated positions, centers, neighbors, + species, cells, cell shifts, system indices, and sample labels. 
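+
+    A minimal sketch of how the outputs relate (assuming the ``systems`` already
+    carry a neighbor list computed with ``neighbor_list_options``)::
+
+        (
+            positions,
+            centers,
+            neighbors,
+            species,
+            cells,
+            cell_shifts,
+            system_indices,
+            sample_labels,
+        ) = concatenate_structures(systems, neighbor_list_options)
+        # positions: (n_atoms_total, 3), atoms of all systems stacked together
+        # centers, neighbors: (n_edges_total,), indices into the stacked positions
+        # cells: (n_systems, 3, 3); cell_shifts: (n_edges_total, 3)
+        # system_indices: (n_atoms_total,), which system each atom came from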
+ """ + + positions: List[torch.Tensor] = [] + centers: List[torch.Tensor] = [] + neighbors: List[torch.Tensor] = [] + species: List[torch.Tensor] = [] + cell_shifts: List[torch.Tensor] = [] + cells: List[torch.Tensor] = [] + system_indices: List[torch.Tensor] = [] + atom_indices: List[torch.Tensor] = [] + node_counter = 0 + + for i, system in enumerate(systems): + assert len(system.known_neighbor_lists()) >= 1, "no neighbor list found" + neighbor_list = system.get_neighbor_list(neighbor_list_options) + nl_values = neighbor_list.samples.values + + centers_values = nl_values[:, 0] + neighbors_values = nl_values[:, 1] + cell_shifts_values = nl_values[:, 2:] + + if selected_atoms is not None: + system_selected_atoms = selected_atoms.values[:, 1][ + selected_atoms.values[:, 0] == i + ] + unique_centers = torch.unique(centers_values) + system_selected_atoms = torch.unique( + torch.cat([system_selected_atoms, unique_centers]) + ) + # calculate the mapping from the ghost atoms to the real atoms + if torch.numel(unique_centers) == 0: + max_center_index = -1 + else: + max_center_index = int(unique_centers.max()) + ghost_to_real_index = torch.full( + [ + max_center_index + 1, + ], + -1, + device=centers_values.device, + dtype=centers_values.dtype, + ) + for j, unique_center_index in enumerate(unique_centers): + ghost_to_real_index[unique_center_index] = j + + centers_values = ghost_to_real_index[centers_values] + neighbors_values = ghost_to_real_index[neighbors_values] + else: + system_selected_atoms = torch.arange( + len(system), device=system.positions.device + ) + + positions.append(system.positions[system_selected_atoms]) + species.append(system.types[system_selected_atoms]) + + centers.append(centers_values + node_counter) + neighbors.append(neighbors_values + node_counter) + cell_shifts.append(cell_shifts_values) + + cells.append(system.cell) + + node_counter += len(system_selected_atoms) + system_indices.append( + torch.full((len(system_selected_atoms),), i, device=system.positions.device) + ) + atom_indices.append( + torch.arange(len(system_selected_atoms), device=system.positions.device) + ) + + positions = torch.cat(positions) + centers = torch.cat(centers) + neighbors = torch.cat(neighbors) + species = torch.cat(species) + cells = torch.stack(cells) + cell_shifts = torch.cat(cell_shifts) + system_indices = torch.cat(system_indices) + atom_indices = torch.cat(atom_indices) + + sample_values = torch.stack( + [system_indices, atom_indices], + dim=1, + ) + sample_labels = Labels( + names=["system", "atom"], + values=sample_values, + assume_unique=True, + ) + + return ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) + + +def systems_to_batch( + systems: List[System], + options: NeighborListOptions, + all_species_list: List[int], + species_to_species_index: torch.Tensor, + cutoff_function: str, + cutoff_width: float, + num_neighbors_adaptive: Optional[float] = None, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, + torch.Tensor, + torch.Tensor, +]: + """ + Converts a list of systems to a batch required for the PET model. + + :param systems: List of systems to convert to a batch. + :param options: Options for the neighbor list. + :param all_species_list: List of all atomic species in the dataset. + :param species_to_species_index: Mapping from atomic species to species indices. 
+ :param cutoff_function: Type of the smoothing function at the cutoff. + :param cutoff_width: Width of the cutoff function for a cutoff mask. + :param num_neighbors_adaptive: Optional maximum number of neighbors per atom. + If provided, the adaptive cutoff scheme will be used for each atom to + approximately select this number of neighbors. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the batch tensors. + The batch consists of the following tensors: + - `element_indices_nodes`: The atomic species of the central atoms + - `element_indices_neighbors`: The atomic species of the neighboring atoms + - `edge_vectors`: The cartesian edge vectors between the central atoms and their + neighbors + - `edge_distances`: The distances between the central atoms and their neighbors + - `padding_mask`: A padding mask indicating which neighbors are real, and which + are padded + - `reverse_neighbor_index`: The reversed neighbor list for each central atom + - `cutoff_factors`: The cutoff function values for each edge + - `system_indices`: The system index for each atom in the batch + - `sample_labels`: Labels indicating the system and atom indices for each atom + + """ + ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) = concatenate_structures(systems, options, selected_atoms) + + # somehow the backward of this operation is very slow at evaluation, + # where there is only one cell, therefore we simplify the calculation + # for that case + if len(cells) == 1: + cell_contributions = cell_shifts.to(cells.dtype) @ cells[0] + else: + cell_contributions = torch.einsum( + "ab, abc -> ac", + cell_shifts.to(cells.dtype), + cells[system_indices[centers]], + ) + edge_vectors = positions[neighbors] - positions[centers] + cell_contributions + edge_distances = torch.norm(edge_vectors, dim=-1) + 1e-15 + + if selected_atoms is not None: + if torch.numel(centers) == 0: + num_nodes = 0 + else: + num_nodes = int(centers.max()) + 1 + else: + num_nodes = len(positions) + + if num_neighbors_adaptive is not None: + with torch.profiler.record_function("PET::get_adaptive_cutoffs"): + # Adaptive cutoff scheme to approximately select `num_neighbors_adaptive` + # neighbors for each atom + atomic_cutoffs = get_adaptive_cutoffs( + centers, + edge_distances, + num_neighbors_adaptive, + num_nodes, + options.cutoff, + cutoff_width=cutoff_width, + ) + # Symmetrize the cutoffs between pairs of atoms (PET needs this symmetry + # due to its corresponding edge indexing ij -> ji) + pair_cutoffs = (atomic_cutoffs[centers] + atomic_cutoffs[neighbors]) / 2.0 + with torch.profiler.record_function("PET::adaptive_cutoff_masking"): + # Apply cutoff mask + cutoff_mask = edge_distances <= pair_cutoffs + + pair_cutoffs = pair_cutoffs[cutoff_mask] + centers = centers[cutoff_mask] + neighbors = neighbors[cutoff_mask] + edge_vectors = edge_vectors[cutoff_mask] + cell_shifts = cell_shifts[cutoff_mask] + edge_distances = edge_distances[cutoff_mask] + else: + pair_cutoffs = options.cutoff * torch.ones( + len(centers), device=positions.device, dtype=positions.dtype + ) + + num_neighbors = torch.bincount(centers, minlength=num_nodes) + max_edges_per_node = int(torch.max(num_neighbors)) + + # uncomment these to print out stats on the adaptive cutoff behavior + # print("adaptive_cutoffs", *pair_cutoffs.tolist()) + # print("num_neighbors", *num_neighbors.tolist()) + + if cutoff_function.lower() == "bump": + # use bump switching 
function for adaptive cutoff
+        cutoff_factors = cutoff_func_bump(edge_distances, pair_cutoffs, cutoff_width)
+    elif cutoff_function.lower() == "cosine":
+        # backward-compatible cosine switching for fixed cutoff
+        cutoff_factors = cutoff_func_cosine(edge_distances, pair_cutoffs, cutoff_width)
+    else:
+        raise ValueError(
+            f"Unknown cutoff function type: {cutoff_function}. "
+            f"Supported types are 'Cosine' and 'Bump'."
+        )
+
+    # Convert to NEF (Node-Edge-Feature) format:
+    nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices(
+        centers, num_nodes, max_edges_per_node
+    )
+
+    # Element indices
+    element_indices_nodes = species_to_species_index[species]
+    element_indices_neighbors = element_indices_nodes[neighbors]
+
+    # Send everything to NEF:
+    edge_vectors = edge_array_to_nef(edge_vectors, nef_indices)
+    edge_distances = torch.sqrt(torch.sum(edge_vectors**2, dim=2) + 1e-15)
+    element_indices_neighbors = edge_array_to_nef(
+        element_indices_neighbors, nef_indices
+    )
+    cutoff_factors = edge_array_to_nef(cutoff_factors, nef_indices, nef_mask, 0.0)
+
+    corresponding_edges = get_corresponding_edges(
+        torch.concatenate(
+            [centers.unsqueeze(-1), neighbors.unsqueeze(-1), cell_shifts],
+            dim=-1,
+        )
+    )
+
+    # These are the two arrays we need for message passing with edge reversals,
+    # if indexing happens in a two-dimensional way:
+    # edges_ji = edges_ij[neighbors_index, reversed_neighbor_list]
+    reversed_neighbor_list = compute_reversed_neighbor_list(
+        nef_indices, corresponding_edges, nef_mask
+    )
+    neighbors_index = edge_array_to_nef(neighbors, nef_indices).to(torch.int64)
+
+    # Here, we compute the array that allows indexing into a flattened
+    # version of the edge array (where the first two dimensions are merged):
+    reverse_neighbor_index = (
+        neighbors_index * neighbors_index.shape[1] + reversed_neighbor_list
+    )
+    # At this point, we have `reverse_neighbor_index[~nef_mask] = 0`, which however
+    # creates too many of the same index which slows down backward enormously.
+    # (See https://github.com/pytorch/pytorch/issues/41162)
+    # We therefore replace the padded indices with a sequence of unique indices.
+    reverse_neighbor_index[~nef_mask] = torch.arange(
+        int(torch.sum(~nef_mask)), device=reverse_neighbor_index.device
+    )
+
+    return (
+        element_indices_nodes,
+        element_indices_neighbors,
+        edge_vectors,
+        edge_distances,
+        nef_mask,
+        reverse_neighbor_index,
+        cutoff_factors,
+        system_indices,
+        sample_labels,
+        centers,
+        nef_to_edges_neighbor,
+    )
+
+
+def get_pair_sample_labels(
+    systems: List[System],
+    sample_labels: Labels,
+    nl_options: NeighborListOptions,
+    device: torch.device,
+) -> Labels:
+    """
+    Builds the pair sample labels for the input ``systems``, based on the pre-computed
+    neighbor list. These are 'off-site', i.e. not including self-interactions.
+
+    :param systems: List of systems to build the pair sample labels for.
+    :param sample_labels: The sample labels for per-atom quantities.
+    :param nl_options: The neighbor list options to use for building the offsite labels.
+    :param device: The device to put the labels on.
+    :return: The pair sample labels for the offsite (atom-pair) blocks.
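+
+    For example (illustrative values), a single system whose neighbor list contains
+    one pair between atoms 0 and 1 with no cell shift contributes one entry with
+    values ``[0, 0, 1, 0, 0, 0]`` for the names ``("system", "first_atom",
+    "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c")``.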
+ """ + sample_names = [ + "system", + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ] + + pair_sample_values = [] + for system_idx, system in enumerate(systems): + neighbor_list = system.get_neighbor_list(nl_options) + nl_values = neighbor_list.samples.values + + pair_sample_values.append( + torch.hstack( + [ + torch.full( + (nl_values.shape[0], 1), + system_idx, + dtype=torch.int32, + device=device, + ), + nl_values, + ], + ) + ) + pair_sample_values = torch.vstack(pair_sample_values) + pair_sample_labels = Labels(sample_names, pair_sample_values).to(device=device) + + return pair_sample_labels diff --git a/src/metatrain/pet_4/modules/transformer.py b/src/metatrain/pet_4/modules/transformer.py new file mode 100644 index 0000000000..decb395090 --- /dev/null +++ b/src/metatrain/pet_4/modules/transformer.py @@ -0,0 +1,555 @@ +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from .utilities import DummyModule +from sphericart.torch import SolidHarmonics + + +AVAILABLE_NORMALIZATIONS = ["LayerNorm", "RMSNorm"] +AVAILABLE_TRANSFORMER_TYPES = ["PostLN", "PreLN"] +AVAILABLE_ACTIVATIONS = ["SiLU", "SwiGLU"] + + +class FeedForward(nn.Module): + def __init__(self, d_model: int, dim_feedforward: int, activation: str) -> None: + super().__init__() + + # Check if activation is "swiglu" string + if activation.lower() == "swiglu": + # SwiGLU mode: single projection produces both "value" and "gate" + self.w_in = nn.Linear(d_model, 2 * dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = torch.nn.Identity() + self.is_swiglu = True + else: + # Standard mode: regular activation function + self.w_in = nn.Linear(d_model, dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = getattr(F, activation.lower()) + self.is_swiglu = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_swiglu: + # SwiGLU activation: split into value and gate + v, g = self.w_in(x).chunk(2, dim=-1) + x = v * torch.sigmoid(g) + x = self.w_out(x) + else: + # Standard activation + x = self.w_in(x) + x = self.activation(x) + x = self.w_out(x) + return x + + +class AttentionBlock(nn.Module): + """ + Multi-head attention block. + + :param total_dim: The total dimension of the input and output tensors. + :param num_heads: The number of attention heads. + :param temperature: An additional scaling factor for attention scores. + This is combined with the standard scaling by the square root of + the head dimension. + :param epsilon: A small value to avoid division by zero. + """ + + def __init__( + self, + total_dim: int, + num_heads: int, + temperature: float, + epsilon: float = 1e-15, + ) -> None: + super(AttentionBlock, self).__init__() + + self.input_linear = nn.Linear(total_dim, 3 * total_dim) + self.output_linear = nn.Linear(total_dim, total_dim) + + self.num_heads = num_heads + self.epsilon = epsilon + self.temperature = temperature + if total_dim % num_heads != 0: + raise ValueError("total dimension is not divisible by the number of heads") + self.head_dim = total_dim // num_heads + + def forward( + self, x: torch.Tensor, cutoff_factors: torch.Tensor, use_manual_attention: bool + ) -> torch.Tensor: + """ + Forward pass for the attention block. 
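+
+        The ``cutoff_factors`` enter the attention as an additive mask,
+        ``log(clamp(cutoff_factors, epsilon))``, on the attention scores. Since
+        ``softmax(scores + log(c))`` is proportional to ``c * exp(scores)``, this is
+        equivalent to multiplying the unnormalized attention weights by the smooth
+        cutoff values, so padded or far-away edges receive (near-)zero attention.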
+ + :param x: The input tensor, of shape (batch_size, seq_length, total_dim) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: The output tensor, of shape (batch_size, seq_length, total_dim) + """ + initial_shape = x.shape + x = self.input_linear(x) + x = x.reshape( + initial_shape[0], initial_shape[1], 3, self.num_heads, self.head_dim + ) + x = x.permute(2, 0, 3, 1, 4) + + queries, keys, values = x[0], x[1], x[2] + attn_weights = torch.clamp(cutoff_factors[:, None, :, :], self.epsilon) + attn_weights = torch.log(attn_weights) + if use_manual_attention: + x = manual_attention(queries, keys, values, attn_weights, self.temperature) + else: + x = torch.nn.functional.scaled_dot_product_attention( + queries, + keys, + values, + attn_mask=attn_weights, + scale=1.0 / (self.head_dim**0.5 * self.temperature), + ) + x = x.transpose(1, 2).reshape(initial_shape) + x = self.output_linear(x) + return x + + +class TransformerLayer(torch.nn.Module): + """ + Single layer of a Transformer. + + :param d_model: The dimension of the model. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param temperature: An additional scaling factor for attention scores. 
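+
+    Illustrative construction (hypothetical sizes)::
+
+        layer = TransformerLayer(
+            d_model=128,
+            n_heads=4,
+            dim_node_features=256,
+            dim_feedforward=512,
+            norm="LayerNorm",
+            activation="SiLU",
+            transformer_type="PostLN",
+        )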
+ """ + + def __init__( + self, + d_model: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + temperature: float = 1.0, + ) -> None: + super(TransformerLayer, self).__init__() + self.attention = AttentionBlock(d_model, n_heads, temperature) + self.transformer_type = transformer_type + self.d_model = d_model + norm_class = getattr(nn, norm) + self.norm_attention = norm_class(d_model) + self.norm_mlp = norm_class(d_model) + self.mlp = FeedForward(d_model, dim_feedforward, activation) + self.expanded_node_features = False + if dim_node_features != d_model: + self.expanded_node_features = True + self.center_contraction = nn.Linear(dim_node_features, d_model) + self.center_expansion = nn.Linear(d_model, dim_node_features) + self.norm_center_features = norm_class(dim_node_features) + self.center_mlp = FeedForward( + dim_node_features, 2 * dim_node_features, activation + ) + else: + self.center_contraction = torch.nn.Identity() + self.center_expansion = torch.nn.Identity() + self.norm_center_features = torch.nn.Identity() + self.center_mlp = torch.nn.Identity() + + def _forward_pre_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + new_tokens = self.attention( + self.norm_attention(tokens), cutoff_factors, use_manual_attention + ) + output_node_embeddings, output_edge_embeddings = torch.split( + new_tokens, [1, new_tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + + output_edge_embeddings = edge_embeddings + output_edge_embeddings + output_edge_embeddings = output_edge_embeddings + self.mlp( + self.norm_mlp(output_edge_embeddings) + ) + + return output_node_embeddings, output_edge_embeddings + + def _forward_post_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + tokens = self.norm_attention( + tokens + self.attention(tokens, cutoff_factors, use_manual_attention) + ) + tokens = self.norm_mlp(tokens + self.mlp(tokens)) + output_node_embeddings, output_edge_embeddings = torch.split( + tokens, [1, tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + return output_node_embeddings, output_edge_embeddings + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward 
pass for a single Transformer layer. + + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + if self.transformer_type == "PostLN": + node_embeddings, edge_embeddings = self._forward_post_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + if self.transformer_type == "PreLN": + node_embeddings, edge_embeddings = self._forward_pre_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class Transformer(torch.nn.Module): + """ + Transformer implementation. + + :param d_model: The dimension of the model. + :param num_layers: The number of transformer layers. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param attention_temperature: The temperature scaling factor for attention + scores. This is combined with the standard scaling by the square root of + the head dimension. + """ + + def __init__( + self, + d_model: int, + num_layers: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + attention_temperature: float = 1.0, + ) -> None: + super(Transformer, self).__init__() + if norm not in AVAILABLE_NORMALIZATIONS: + raise ValueError( + f"Unknown normalization flag: {norm}. " + f"Please choose from: {AVAILABLE_NORMALIZATIONS}" + ) + + if transformer_type not in AVAILABLE_TRANSFORMER_TYPES: + raise ValueError( + f"Unknown transformer flag: {transformer_type}. " + f"Please choose from: {AVAILABLE_TRANSFORMER_TYPES}" + ) + self.transformer_type = transformer_type + + if activation not in AVAILABLE_ACTIVATIONS: + raise ValueError( + f"Unknown activation flag: {activation}. " + f"Please choose from: {AVAILABLE_ACTIVATIONS}" + ) + + self.layers = nn.ModuleList( + [ + TransformerLayer( + d_model=d_model, + n_heads=n_heads, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + temperature=attention_temperature, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the Transformer. 
+ + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + for layer in self.layers: + node_embeddings, edge_embeddings = layer( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class CartesianTransformer(torch.nn.Module): + """ + Cartesian Transformer implementation for handling 3D coordinates. + + :param cutoff: The cutoff distance for neighbor interactions. + :param cutoff_width: The width of the cutoff function. + :param d_model: The dimension of the model. + :param n_head: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param n_layers: The number of transformer layers. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param attention_temperature: The temperature scaling factor for attention scores. + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param n_atomic_species: The number of atomic species. + :param is_first: Whether this is the first transformer in the model. + """ + + def __init__( + self, + cutoff: float, + cutoff_width: float, + d_model: int, + n_head: int, + dim_node_features: int, + dim_feedforward: int, + n_layers: int, + norm: str, + activation: str, + attention_temperature: float, + transformer_type: str, + n_atomic_species: int, + is_first: bool, + ) -> None: + super(CartesianTransformer, self).__init__() + self.is_first = is_first + self.cutoff = cutoff + self.cutoff_width = cutoff_width + self.trans = Transformer( + d_model=d_model, + num_layers=n_layers, + n_heads=n_head, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + attention_temperature=attention_temperature, + ) + + self.spherical_harmonics = SolidHarmonics(l_max=10) + self.edge_embedder = nn.Linear(11**2, d_model) + self.rmsnorm = nn.LayerNorm(d_model) + + if not is_first: + n_merge = 3 + else: + n_merge = 2 + + self.compress = nn.Sequential( + nn.Linear(n_merge * d_model, d_model), + torch.nn.SiLU(), + nn.Linear(d_model, d_model), + ) + + self.neighbor_embedder = DummyModule() # for torchscript + if not is_first: + self.neighbor_embedder = nn.Embedding(n_atomic_species, d_model) + + def forward( + self, + input_node_embeddings: torch.Tensor, + input_messages: torch.Tensor, + element_indices_neighbors: torch.Tensor, + edge_vectors: torch.Tensor, + padding_mask: torch.Tensor, + edge_distances: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the CartesianTransformer. 
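+
+        Internally, the edge geometry is embedded by evaluating solid harmonics of
+        the edge vectors up to ``l_max=10`` (``(10 + 1) ** 2 = 121`` features per
+        edge) followed by a linear projection to ``d_model``; these edge embeddings
+        are concatenated with the incoming messages (and, in all but the first GNN
+        layer, with the neighbor element embeddings) and compressed back to
+        ``d_model`` before being passed to the transformer together with the node
+        embeddings.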
+ + :param input_node_embeddings: The input node embeddings, of shape + (n_nodes, d_model) + :param input_messages: The input messages to the transformer, of shape + (n_nodes, max_num_neighbors, d_model) + :param element_indices_neighbors: The atomic species of the neighboring atoms, + of shape (n_nodes, max_num_neighbors) + :param edge_vectors: The cartesian edge vectors between the central atoms and + their neighbors, of shape (n_nodes, max_num_neighbors, 3) + :param padding_mask: A padding mask indicating which neighbors are real, and + which are padded, of shape (n_nodes, max_num_neighbors) + :param edge_distances: The distances between the central atoms and their + neighbors, of shape (n_nodes, max_num_neighbors) + :param cutoff_factors: The cutoff factors for the edges, of shape + (n_nodes, max_num_neighbors) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (n_nodes, d_model) + - The output edge embeddings, of shape (n_nodes, max_num_neighbors, d_model) + """ + node_embeddings = input_node_embeddings + # edge_embeddings = [edge_vectors, edge_distances[:, :, None]] + # edge_embeddings = torch.cat(edge_embeddings, dim=2) + edge_embeddings = self.spherical_harmonics( + edge_vectors.reshape(-1, 3) + ).reshape(edge_vectors.shape[0], edge_vectors.shape[1], -1) + edge_embeddings = self.edge_embedder(edge_embeddings) + edge_embeddings = self.rmsnorm(edge_embeddings) + + if not self.is_first: + neighbor_elements_embeddings = self.neighbor_embedder( + element_indices_neighbors + ) + edge_tokens = torch.cat( + [edge_embeddings, neighbor_elements_embeddings, input_messages], dim=2 + ) + else: + neighbor_elements_embeddings = torch.empty( + 0, device=edge_vectors.device, dtype=edge_vectors.dtype + ) # for torch script + edge_tokens = torch.cat([edge_embeddings, input_messages], dim=2) + + edge_tokens = self.compress(edge_tokens) + # tokens = torch.cat([node_elements_embedding[:, None, :], tokens], dim=1) + + padding_mask_with_central_token = torch.ones( + padding_mask.shape[0], dtype=torch.bool, device=padding_mask.device + ) + total_padding_mask = torch.cat( + [padding_mask_with_central_token[:, None], padding_mask], dim=1 + ) + + cutoff_subfactors = torch.ones( + padding_mask.shape[0], + dtype=cutoff_factors.dtype, + device=padding_mask.device, + ) + cutoff_factors = torch.cat([cutoff_subfactors[:, None], cutoff_factors], dim=1) + cutoff_factors[~total_padding_mask] = 0.0 + + cutoff_factors = cutoff_factors[:, None, :] + cutoff_factors = cutoff_factors.repeat(1, cutoff_factors.shape[2], 1) + + initial_num_tokens = edge_vectors.shape[1] + max_num_tokens = input_messages.shape[1] + + output_node_embeddings, output_edge_embeddings = self.trans( + node_embeddings[:, None, :], + edge_tokens[:, :max_num_tokens, :], + cutoff_factors=cutoff_factors[ + :, : (max_num_tokens + 1), : (max_num_tokens + 1) + ], + use_manual_attention=use_manual_attention, + ) + if max_num_tokens < initial_num_tokens: + padding = torch.zeros( + output_edge_embeddings.shape[0], + initial_num_tokens - max_num_tokens, + output_edge_embeddings.shape[2], + device=output_edge_embeddings.device, + ) + output_edge_embeddings = torch.cat([output_edge_embeddings, padding], dim=1) + output_node_embeddings = output_node_embeddings.squeeze(1) + return output_node_embeddings, 
output_edge_embeddings + + +def manual_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: torch.Tensor, + temperature: float, +) -> torch.Tensor: + """ + Implements the attention operation manually, using basic PyTorch operations. + We need it because the built-in PyTorch attention does not support double backward, + which is needed when training with conservative forces. + + :param q: The queries + :param k: The keys + :param v: The values + :param attn_mask: The attention mask + :param temperature: An additional scaling factor for attention scores. + :return: The result of the attention operation + """ + attention_weights = ( + torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5 * temperature) + ) + attn_mask + attention_weights = attention_weights.softmax(dim=-1) + attention_output = torch.matmul(attention_weights, v) + return attention_output \ No newline at end of file diff --git a/src/metatrain/pet_4/modules/utilities.py b/src/metatrain/pet_4/modules/utilities.py new file mode 100644 index 0000000000..5795c5aed8 --- /dev/null +++ b/src/metatrain/pet_4/modules/utilities.py @@ -0,0 +1,63 @@ +import torch + + +def cutoff_func_bump( + values: torch.Tensor, cutoff: torch.Tensor, width: float, eps: float = 1e-6 +) -> torch.Tensor: + """ + Bump cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :param eps: Avoid computing at values too close to the edges. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + f[mask_active] = 0.5 * ( + 1 + torch.tanh(1 / torch.tan(torch.pi * scaled_values[mask_active])) + ) + f[mask_smaller] = 1.0 + + return f + + +def cutoff_func_cosine( + values: torch.Tensor, cutoff: torch.Tensor, width: float +) -> torch.Tensor: + """ + Cosine cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + + f[mask_active] = 0.5 + 0.5 * torch.cos(torch.pi * scaled_values[mask_active]) + f[mask_smaller] = 1.0 + return f + + +class DummyModule(torch.nn.Module): + """Dummy torch module to make torchscript happy. 
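+ It is assigned as a typed placeholder (for example, ``neighbor_embedder`` in
+ the first ``CartesianTransformer`` layer) so that the attribute always holds
+ a module of a consistent type under TorchScript.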
+ This model should never be run""" + + def __init__(self) -> None: + super(DummyModule, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise RuntimeError("This model should never be run") diff --git a/src/metatrain/pet_4/trainer.py b/src/metatrain/pet_4/trainer.py new file mode 100644 index 0000000000..49e82212bc --- /dev/null +++ b/src/metatrain/pet_4/trainer.py @@ -0,0 +1,639 @@ +import copy +import logging +import math +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Union, cast + +import torch +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader, DistributedSampler + +from metatrain.utils.abc import ModelInterface, TrainerInterface +from metatrain.utils.additive import get_remove_additive_transform +from metatrain.utils.augmentation import RotationalAugmenter +from metatrain.utils.data import ( + CollateFn, + CombinedDataLoader, + Dataset, + get_num_workers, + unpack_batch, + validate_num_workers, +) +from metatrain.utils.distributed.batch_utils import should_skip_batch +from metatrain.utils.distributed.distributed_data_parallel import ( + DistributedDataParallel, +) +from metatrain.utils.distributed.slurm import DistributedEnvironment +from metatrain.utils.evaluate_model import evaluate_model +from metatrain.utils.io import check_file_extension +from metatrain.utils.logging import ROOT_LOGGER, MetricLogger +from metatrain.utils.loss import LossAggregator, LossSpecification +from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists_transform, +) +from metatrain.utils.per_atom import average_by_num_atoms +from metatrain.utils.scaler import get_remove_scale_transform +from metatrain.utils.transfer import batch_to + +from . import checkpoints +from .documentation import TrainerHypers +from .model import PET +from .modules.finetuning import apply_finetuning_strategy + + +def get_scheduler( + optimizer: torch.optim.Optimizer, + train_hypers: TrainerHypers, + steps_per_epoch: int, +) -> LambdaLR: + """ + Get a CosineAnnealing learning-rate scheduler with warmup + + :param optimizer: The optimizer for which to create the scheduler. + :param train_hypers: The training hyperparameters. + :param steps_per_epoch: The number of steps per epoch. + :return: The learning rate scheduler. 
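+
+ For illustration (values chosen only as an example), with ``num_epochs=100``,
+ ``steps_per_epoch=10`` and the default ``warmup_fraction=0.01``, the schedule
+ has ``total_steps=1000`` and ``warmup_steps=10``: the factor applied to the
+ base learning rate grows linearly from 0 to 1 over the first 10 steps and
+ then follows a cosine decay down to 0 at step 1000.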
+ """ + total_steps = train_hypers["num_epochs"] * steps_per_epoch + warmup_steps = int(train_hypers["warmup_fraction"] * total_steps) + min_lr_ratio = 0.0 # hardcoded for now, could be made configurable in the future + + def lr_lambda(current_step: int) -> float: + if current_step < warmup_steps: + # Linear warmup + return float(current_step) / float(max(1, warmup_steps)) + else: + # Cosine decay + progress = (current_step - warmup_steps) / float( + max(1, total_steps - warmup_steps) + ) + cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) + return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay + + scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) + return scheduler + + +class Trainer(TrainerInterface[TrainerHypers]): + __checkpoint_version__ = 12 + + def __init__(self, hypers: TrainerHypers) -> None: + super().__init__(hypers) + + self.optimizer_state_dict: Optional[Dict[str, Any]] = None + self.scheduler_state_dict: Optional[Dict[str, Any]] = None + self.epoch: Optional[int] = None + self.best_epoch: Optional[int] = None + self.best_metric: Optional[float] = None + self.best_model_state_dict: Optional[Dict[str, Any]] = None + self.best_optimizer_state_dict: Optional[Dict[str, Any]] = None + + def train( + self, + model: PET, + dtype: torch.dtype, + devices: List[torch.device], + train_datasets: List[Union[Dataset, torch.utils.data.Subset]], + val_datasets: List[Union[Dataset, torch.utils.data.Subset]], + checkpoint_dir: str, + ) -> None: + assert dtype in PET.__supported_dtypes__ + + is_distributed = self.hypers["distributed"] + is_finetune = self.hypers["finetune"]["read_from"] is not None + + if is_distributed: + if len(devices) > 1: + raise ValueError( + "Requested distributed training with the `multi-gpu` device. " + " If you want to run distributed training with PET, please " + "set `device` to cuda." 
+ ) + # the calculation of the device number works both when GPUs on different + # processes are not visible to each other and when they are + distr_env = DistributedEnvironment(self.hypers["distributed_port"]) + device_number = distr_env.local_rank % torch.cuda.device_count() + device = torch.device("cuda", device_number) + torch.distributed.init_process_group(backend="nccl", device_id=device) + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + else: + rank = 0 + device = devices[0] + # only one device, as we don't support non-distributed multi-gpu for now + + if is_distributed: + logging.info(f"Training on {world_size} devices with dtype {dtype}") + else: + logging.info(f"Training on device {device} with dtype {dtype}") + + # Apply fine-tuning strategy if provided + if is_finetune: + assert self.hypers["finetune"]["read_from"] is not None # for mypy + model = apply_finetuning_strategy(model, self.hypers["finetune"]) + method = self.hypers["finetune"]["method"] + num_params = sum(p.numel() for p in model.parameters()) + num_trainable_params = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + + logging.info(f"Applied finetuning strategy: {method}") + logging.info( + f"Number of trainable parameters: {num_trainable_params} " + f"[{num_trainable_params / num_params:.2%} %]" + ) + inherit_heads = self.hypers["finetune"]["inherit_heads"] + if inherit_heads: + logging.info( + "Inheriting initial weights for heads and last layers for targets: " + f"from {list(inherit_heads.values())} to " + f"{list(inherit_heads.keys())}" + ) + + # Move the model to the device and dtype: + model.to(device=device, dtype=dtype) + # The additive models of PET are always in float64 (to avoid numerical errors in + # the composition weights, which can be very large). 
+ for additive_model in model.additive_models: + additive_model.to(dtype=torch.float64) + model.scaler.to(dtype=torch.float64) + + logging.info("Calculating composition weights") + model.additive_models[0].train_model( # this is the composition model + train_datasets, + model.additive_models[1:], + self.hypers["batch_size"], + is_distributed, + self.hypers["atomic_baseline"], + ) + + if self.hypers["scale_targets"]: + logging.info("Calculating scaling weights") + model.scaler.train_model( + train_datasets, + model.additive_models, + self.hypers["batch_size"], + is_distributed, + self.hypers["fixed_scaling_weights"], + ) + + logging.info("Setting up data loaders") + + if is_distributed: + train_samplers = [ + DistributedSampler( + train_dataset, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + ) + for train_dataset in train_datasets + ] + val_samplers = [ + DistributedSampler( + val_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + for val_dataset in val_datasets + ] + else: + train_samplers = [None] * len(train_datasets) + val_samplers = [None] * len(val_datasets) + + # Extract additive models and scaler and move them to CPU/float64 so they + # can be used in the collate function + model.additive_models[0].weights_to(device="cpu", dtype=torch.float64) + additive_models = copy.deepcopy( + model.additive_models.to(dtype=torch.float64, device="cpu") + ) + model.additive_models.to(device) + model.additive_models[0].weights_to(device=device, dtype=torch.float64) + model.scaler.scales_to(device="cpu", dtype=torch.float64) + scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu")) + model.scaler.to(device) + model.scaler.scales_to(device=device, dtype=torch.float64) + + # Create collate functions: + dataset_info = model.dataset_info + train_targets = dataset_info.targets + extra_data_info = dataset_info.extra_data + rotational_augmenter = RotationalAugmenter( + target_info_dict=train_targets, extra_data_info_dict=extra_data_info + ) + requested_neighbor_lists = get_requested_neighbor_lists(model) + collate_fn_train = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ + rotational_augmenter.apply_random_augmentations, + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + collate_fn_val = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ # no augmentation for validation + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + + # Create dataloader for the training datasets: + if self.hypers["num_workers"] is None: + num_workers = get_num_workers() + logging.info( + "Number of workers for data-loading not provided and chosen " + f"automatically. Using {num_workers} workers." + ) + else: + num_workers = self.hypers["num_workers"] + validate_num_workers(num_workers) + + train_dataloaders = [] + for train_dataset, train_sampler in zip( + train_datasets, train_samplers, strict=True + ): + if len(train_dataset) < self.hypers["batch_size"]: + raise ValueError( + f"A training dataset has fewer samples " + f"({len(train_dataset)}) than the batch size " + f"({self.hypers['batch_size']}). " + "Please reduce the batch size." 
+ ) + train_dataloaders.append( + DataLoader( + dataset=train_dataset, + batch_size=self.hypers["batch_size"], + sampler=train_sampler, + shuffle=( + # the sampler takes care of this (if present) + train_sampler is None + ), + drop_last=( + # the sampler takes care of this (if present) + train_sampler is None + ), + collate_fn=collate_fn_train, + num_workers=num_workers, + ) + ) + train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True) + + # Create dataloader for the validation datasets: + val_dataloaders = [] + for val_dataset, val_sampler in zip(val_datasets, val_samplers, strict=True): + val_dataloaders.append( + DataLoader( + dataset=val_dataset, + batch_size=self.hypers["batch_size"], + sampler=val_sampler, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_val, + num_workers=num_workers, + ) + ) + val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False) + + if is_distributed: + model = DistributedDataParallel(model, device_ids=[device]) + + outputs_list = [] + for target_name, target_info in train_targets.items(): + outputs_list.append(target_name) + for gradient_name in target_info.gradients: + outputs_list.append(f"{target_name}_{gradient_name}_gradients") + + # Create a loss function: + loss_hypers = cast(Dict[str, LossSpecification], self.hypers["loss"]) # mypy + loss_fn = LossAggregator(targets=train_targets, config=loss_hypers) + logging.info("Using the following loss functions:") + for name, info in loss_fn.metadata.items(): + logging.info(f"{name}:") + main = {k: v for k, v in info.items() if k != "gradients"} + logging.info(main) + if "gradients" not in info or len(info["gradients"]) == 0: + continue + logging.info("With gradients:") + for grad, ginfo in info["gradients"].items(): + logging.info(f"\t{name}::{grad}: {ginfo}") + + if self.hypers["weight_decay"] is not None: + optimizer = torch.optim.AdamW( + model.parameters(), + lr=self.hypers["learning_rate"], + weight_decay=self.hypers["weight_decay"], + ) + else: + optimizer = torch.optim.Adam( + model.parameters(), lr=self.hypers["learning_rate"] + ) + + if self.optimizer_state_dict is not None and not is_finetune: + # try to load the optimizer state dict, but this is only possible + # if there are no new targets in the model (new parameters) + if not (model.module if is_distributed else model).has_new_targets: + optimizer.load_state_dict(self.optimizer_state_dict) + + # Create a learning rate scheduler + lr_scheduler = get_scheduler(optimizer, self.hypers, len(train_dataloader)) + + if self.scheduler_state_dict is not None and not is_finetune: + # same as the optimizer, try to load the scheduler state dict + if not (model.module if is_distributed else model).has_new_targets: + lr_scheduler.load_state_dict(self.scheduler_state_dict) + + per_structure_targets = self.hypers["per_structure_targets"] + + # Log the initial learning rate: + logging.info(f"Base learning rate: {self.hypers['learning_rate']}") + + start_epoch = -1 if self.epoch is None else self.epoch + epoch = start_epoch + + # Save the untrained model checkpoint: + if self.best_metric is None: + self.best_metric = float("inf") + logging.info("Saving untrained model checkpoint") + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = 
lr_scheduler.state_dict() + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_untrained.ckpt", + ) + + # Train the model: + logging.info("Starting training") + start_epoch = start_epoch + 1 + + for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]): + if is_distributed: + for train_sampler in train_samplers: + train_sampler.set_epoch(epoch) + train_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + val_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + if self.hypers["log_mae"]: + train_mae_calculator = MAEAccumulator( + self.hypers["log_separate_blocks"] + ) + val_mae_calculator = MAEAccumulator(self.hypers["log_separate_blocks"]) + + train_loss = 0.0 + for batch in train_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + optimizer.zero_grad() + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=True, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + train_loss_batch = loss_fn(predictions, targets, extra_data) + + if is_distributed: + # make sure all parameters contribute to the gradient calculation + # to make torch DDP happy + for param in model.parameters(): + train_loss_batch += 0.0 * param.sum() + + train_loss_batch.backward() + torch.nn.utils.clip_grad_norm_( + model.parameters(), self.hypers["grad_clip_norm"] + ) + optimizer.step() + lr_scheduler.step() + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(train_loss_batch) + train_loss += train_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + train_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + train_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_train_info = train_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_train_info.update( + train_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + val_loss = 0.0 + for batch in val_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=False, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + val_loss_batch = loss_fn(predictions, targets, extra_data) + + if 
is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(val_loss_batch) + val_loss += val_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + val_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + val_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_val_info = val_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_val_info.update( + val_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + # Now we log the information: + finalized_train_info = { + "loss": train_loss, + **finalized_train_info, + } + finalized_val_info = { + "loss": val_loss, + **finalized_val_info, + } + + if epoch == start_epoch: + metric_logger = MetricLogger( + log_obj=ROOT_LOGGER, + dataset_info=( + model.module if is_distributed else model + ).dataset_info, + initial_metrics=[finalized_train_info, finalized_val_info], + names=["training", "validation"], + ) + if epoch % self.hypers["log_interval"] == 0: + metric_logger.log( + metrics=[finalized_train_info, finalized_val_info], + epoch=epoch, + rank=rank, + learning_rate=optimizer.param_groups[0]["lr"], + ) + + val_metric = get_selected_metric( + finalized_val_info, self.hypers["best_model_metric"] + ) + if val_metric < self.best_metric: + self.best_metric = val_metric + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + if epoch % self.hypers["checkpoint_interval"] == 0: + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + self.epoch = epoch + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_{epoch}.ckpt", + ) + + # prepare for the checkpoint that will be saved outside the function + self.epoch = epoch + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + + if is_distributed: + torch.distributed.destroy_process_group() + + def save_checkpoint(self, model: ModelInterface, path: Union[str, Path]) -> None: + checkpoint = model.get_checkpoint() + if self.best_model_state_dict is not None: + self.best_model_state_dict["finetune_config"] = model.finetune_config + checkpoint.update( + { + "trainer_ckpt_version": self.__checkpoint_version__, + "train_hypers": self.hypers, + "epoch": self.epoch, + "optimizer_state_dict": self.optimizer_state_dict, + "scheduler_state_dict": self.scheduler_state_dict, + "best_epoch": self.best_epoch, + "best_metric": self.best_metric, + "best_model_state_dict": self.best_model_state_dict, + "best_optimizer_state_dict": self.best_optimizer_state_dict, + } + ) + torch.save( + checkpoint, + check_file_extension(path, ".ckpt"), + ) + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + hypers: TrainerHypers, + context: Literal["restart", "finetune"], + ) -> "Trainer": + trainer = cls(hypers) + trainer.optimizer_state_dict = 
checkpoint["optimizer_state_dict"] + trainer.scheduler_state_dict = checkpoint["scheduler_state_dict"] + trainer.epoch = checkpoint["epoch"] + trainer.best_epoch = checkpoint["best_epoch"] + trainer.best_metric = checkpoint["best_metric"] + trainer.best_model_state_dict = checkpoint["best_model_state_dict"] + trainer.best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"] + + return trainer + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["trainer_ckpt_version"] == v: + update = getattr(checkpoints, f"trainer_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["trainer_ckpt_version"] = v + 1 + + if checkpoint["trainer_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using " + f"trainer version {checkpoint['trainer_ckpt_version']}, while the " + f"current trainer version is {cls.__checkpoint_version__}." + ) + return checkpoint diff --git a/src/metatrain/pet_5/__init__.py b/src/metatrain/pet_5/__init__.py new file mode 100644 index 0000000000..01db7d8aa9 --- /dev/null +++ b/src/metatrain/pet_5/__init__.py @@ -0,0 +1,21 @@ +from .model import PET +from .trainer import Trainer + + +__model__ = PET +__trainer__ = Trainer +__capabilities__ = { + "supported_devices": __model__.__supported_devices__, + "supported_dtypes": __model__.__supported_dtypes__, +} + +__authors__ = [ + ("Sergey Pozdnyakov ", "@spozdn"), + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] + +__maintainers__ = [ + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] diff --git a/src/metatrain/pet_5/checkpoints.py b/src/metatrain/pet_5/checkpoints.py new file mode 100644 index 0000000000..020b1e6590 --- /dev/null +++ b/src/metatrain/pet_5/checkpoints.py @@ -0,0 +1,422 @@ +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + + +########################### +# MODEL ################### +########################### + + +def model_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "additive_models.0.model.type_to_index" not in state_dict: + state_dict["additive_models.0.model.type_to_index"] = state_dict.pop( + "additive_models.0.type_to_index" + ) + + +def model_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "train_hypers" in checkpoint: + finetune_config = checkpoint["train_hypers"].get("finetune", {}) + else: + finetune_config = {} + state_dict["finetune_config"] = finetune_config + + +def model_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["epoch"] = checkpoint.get("epoch") + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + if checkpoint["best_model_state_dict"] is not None: + checkpoint["best_model_state_dict"] = checkpoint.get("best_model_state_dict") + + +def model_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 checkpoint to v5. + + :param checkpoint: The checkpoint to update. 
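+
+ This update casts the ``_atomic_types`` attribute of the stored
+ ``dataset_info`` to a plain ``list``.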
+ """ + checkpoint["model_data"]["dataset_info"]._atomic_types = list( + checkpoint["model_data"]["dataset_info"]._atomic_types + ) + + +def model_update_v5_v6(checkpoint: dict) -> None: + """ + Update a v5 checkpoint to v6. + + :param checkpoint: The checkpoint to update. + """ + if not checkpoint["best_model_state_dict"]: + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + +def model_update_v6_v7(checkpoint: dict) -> None: + """ + Update model checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to be updated. + """ + # this update consists in changes in the scaler + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if ( + "scaler.scales" not in state_dict + and "scaler.dummy_buffer" in state_dict + and "scaler.model.type_to_index" in state_dict + ): + continue # already updated + old_scales_tensor = state_dict.pop("scaler.scales") + old_output_name_to_output_index = {} + for target_index, target_name in enumerate( + checkpoint["model_data"]["dataset_info"].targets.keys() + ): + old_output_name_to_output_index[target_name] = target_index + state_dict["scaler.dummy_buffer"] = torch.tensor( + [0.0], dtype=old_scales_tensor.dtype + ) + state_dict["scaler.model.type_to_index"] = state_dict[ + "additive_models.0.model.type_to_index" + ] + for target_name, target_info in checkpoint["model_data"][ + "dataset_info" + ].targets.items(): + layout = target_info.layout + if layout.sample_names == ["system"]: + samples = Labels(["atomic_type"], torch.tensor([[-1]])) + + elif layout.sample_names == ["system", "atom"]: + samples = Labels( + ["atomic_type"], + torch.arange( + len(checkpoint["model_data"]["dataset_info"].atomic_types) + ).reshape(-1, 1), + ) + else: + raise ValueError( # will never happen + "Unknown sample kind. Please contact the developers." + ) + scales_tensormap = TensorMap( + keys=layout.keys, + blocks=[ + TensorBlock( + values=torch.full( # important when scale_targets=False + (len(samples), len(block.properties)), + old_scales_tensor[ + old_output_name_to_output_index[target_name] + ], + dtype=torch.float64, + ), + samples=samples, + components=[], + properties=block.properties, + ) + for block in layout.blocks() + ], + ) + state_dict[f"scaler.{target_name}_scaler_buffer"] = mts.save_buffer( + mts.make_contiguous(scales_tensormap) + ) + + +def model_update_v7_v8(checkpoint: dict) -> None: + """ + Update a v7 checkpoint to v8. + + :param checkpoint: The checkpoint to update. + """ + if "node_embedding.weight" in checkpoint["model_state_dict"]: + ############################################## + # **Updating the large-scale PET checkpoints** + ############################################## + + checkpoint["model_data"]["model_hypers"]["normalization"] = "RMSNorm" + checkpoint["model_data"]["model_hypers"]["activation"] = "SwiGLU" + checkpoint["model_data"]["model_hypers"]["d_node"] = ( + 4 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PreLN" + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "feedforward" + checkpoint["model_data"]["model_hypers"]["d_feedforward"] = ( + 2 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + if (state_dict := checkpoint.get("model_state_dict")) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + if "embedding." 
in k and "node" not in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if "node_embedding." in k: + k = k.replace("node_embedding.", "node_embedders.0.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + if "combination_rmsnorms" in k: + k = k.replace("combination_rmsnorms", "combination_norms") + new_state_dict[k] = v + checkpoint["model_state_dict"] = new_state_dict + + # for the large-scale checkpoints, the model for evaluation is always + # taken to be the last + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + else: + ############################################### + # **Updating the standard old PET checkpoints** + ############################################### + + # Adding the option for choosing the normalization type + if "normalization" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["normalization"] = "LayerNorm" + # Adding the option for choosing the activation function + if "activation" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["activation"] = "SiLU" + # Setting the node features dimension to be the same as d_pet if not specified + if "d_node" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["d_node"] = checkpoint[ + "model_data" + ]["model_hypers"]["d_pet"] + # Setting the default transformer type to PostLN if not specified + if "transformer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PostLN" + if "featurizer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "residual" + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + # Replacing the nn.Sequential MLP with a custom FeedForward module + if ".mlp.0" in k: + k = k.replace(".mlp.0", ".mlp.w_in") + if ".mlp.3" in k: + k = k.replace(".mlp.3", ".mlp.w_out") + # Moving the node embedder to a top-level PET model attribute + if "embedding." in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".node_embedder." in k: + key_content = k.split(".") + k = ".".join( + ["node_embedders", key_content[1], key_content[-1]] + ) + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + new_state_dict[k] = v + checkpoint[key] = new_state_dict + + +def model_update_v8_v9(checkpoint: dict) -> None: + """ + Update a v8 checkpoint to v9. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "finetune_config" in state_dict: + if "inherit_heads" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["inherit_heads"] = {} + if "method" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["method"] = "full" + + +def model_update_v9_v10(checkpoint: dict) -> None: + """ + Update a v9 checkpoint to v10. + + :param checkpoint: The checkpoint to update. 
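+
+ This update adds default values for the ``num_neighbors_adaptive`` (``None``)
+ and ``cutoff_function`` (``"Cosine"``) model hyperparameters if they are not
+ already present.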
+ """ + # Adding the num_neighbors_adaptive hyperparameter if not present + if "num_neighbors_adaptive" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["num_neighbors_adaptive"] = None + if "cutoff_function" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["cutoff_function"] = "Cosine" + + +def model_update_v10_v11(checkpoint: dict) -> None: + """ + Update a v10 checkpoint to v11. + + :param checkpoint: The checkpoint to update. + """ + # Adding the attention_temperature hyperparameter if not present + if "attention_temperature" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["attention_temperature"] = 1.0 + + +########################### +# TRAINER ################# +########################### + + +def trainer_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 Trainer checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["scheduler_factor"] = checkpoint["train_hypers"].get( + "scheduler_factor", 0.5 + ) + + +def trainer_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 Trainer checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + +def trainer_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 Trainer checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + new_loss_hypers[target_name] = { + "type": old_loss_hypers["type"], + "weight": old_loss_hypers["weights"].get(target_name, 1.0), + "reduction": old_loss_hypers["reduction"], + "sliding_factor": old_loss_hypers.get("sliding_factor", None), + } + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 Trainer checkpoint to v5. + + :param checkpoint: The checkpoint to update. + """ + raise ValueError( + "In order to use this checkpoint, you need metatrain 2025.10 or earlier. " + "You can install it with `pip install metatrain==2025.10`." + ) + + +def trainer_update_v5_v6(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 5 to version 6. + + :param checkpoint: The checkpoint to update. + """ + # num_workers=0 means that the main process will do the data loading, which is + # equivalent to not setting it (this was the behavior before v6) + checkpoint["train_hypers"]["num_workers"] = 0 + + +def trainer_update_v6_v7(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["fixed_scaling_weights"] = {} + + +def trainer_update_v7_v8(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. 
+ """ + # remove all entries in the loss `sliding_factor` + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + # retain everything except sliding_factor for each target + new_loss_hypers[target_name] = { + k: v + for k, v in old_loss_hypers[target_name].items() + if k != "sliding_factor" + } + + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v8_v9(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. + """ + # Adding the empty finetune config if not present + if "finetune" not in checkpoint["train_hypers"]: + checkpoint["train_hypers"]["finetune"] = { + "read_from": None, + "method": "full", + "config": {}, + "inherit_heads": {}, + } + + +def trainer_update_v9_v10(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 9 to version 10. + + :param checkpoint: The checkpoint to update. + """ + # Ensuring that the finetune read_from is None if not specified + checkpoint["train_hypers"]["remove_composition_contribution"] = True + + +def trainer_update_v10_v11(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 10 to version 11. + + :param checkpoint: The checkpoint to update. + """ + # - Remove the ``remove_composition_contribution`` hyper. + # - Rename ``fixed_composition_weights`` to ``atomic_baseline``. + # - If ``remove_composition_contribution`` is False, set all atomic baselines + # to 0.0 for all targets. + use_atomic_baseline = checkpoint["train_hypers"].pop( + "remove_composition_contribution" + ) + atomic_baseline = checkpoint["train_hypers"].pop("fixed_composition_weights") + + if not use_atomic_baseline: + # Just set + dataset_info = checkpoint["model_data"]["dataset_info"] + atomic_baseline = {target_name: 0.0 for target_name in dataset_info.targets} + + checkpoint["train_hypers"]["atomic_baseline"] = atomic_baseline + + +def trainer_update_v11_v12(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 11 to version 12. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["batch_atom_bounds"] = [None, None] diff --git a/src/metatrain/pet_5/documentation.py b/src/metatrain/pet_5/documentation.py new file mode 100644 index 0000000000..6f24f3abdb --- /dev/null +++ b/src/metatrain/pet_5/documentation.py @@ -0,0 +1,251 @@ +""" +PET +=== + +PET is a cleaner, more user-friendly reimplementation of the original +PET model :footcite:p:`pozdnyakov_smooth_2023`. It is designed for better +modularity and maintainability, while preseving compatibility with the original +PET implementation in ``metatrain``. It also adds new features like long-range +features, better fine-tuning implementation, a possibility to train on +arbitrarty targets, and a faster inference due to the ``fast attention``. + +{{SECTION_INSTALLATION}} + +{{SECTION_DEFAULT_HYPERS}} + +Tuning hyperparameters +---------------------- + +The default hyperparameters above will work well in most cases, but they +may not be optimal for your specific dataset. There is good number of +parameters to tune, both for the +:ref:`model ` and the +:ref:`trainer `. Since seeing them +for the first time might be overwhelming, here we provide a **list of the +parameters that are in general the most important** (in decreasing order +of importance): + +.. container:: mtt-hypers-remove-classname + + .. 
autoattribute:: {{model_hypers_path}}.cutoff + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_neighbors_adaptive + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.learning_rate + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.batch_size + :no-index: + + .. autoattribute:: {{model_hypers_path}}.d_pet + :no-index: + + .. autoattribute:: {{model_hypers_path}}.d_node + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_gnn_layers + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_attention_layers + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.loss + :no-index: + + .. autoattribute:: {{model_hypers_path}}.long_range + :no-index: +""" + +from typing import Literal, Optional + +from typing_extensions import TypedDict + +from metatrain.pet_5.modules.finetuning import FinetuneHypers, NoFinetuneHypers +from metatrain.utils.additive import FixedCompositionWeights +from metatrain.utils.hypers import init_with_defaults +from metatrain.utils.long_range import LongRangeHypers +from metatrain.utils.loss import LossSpecification +from metatrain.utils.scaler import FixedScalerWeights + + +class ModelHypers(TypedDict): + """Hyperparameters for the PET model.""" + + cutoff: float = 4.5 + """Cutoff radius for neighbor search. + + This should be set to a value after which most of the interactions + between atoms is expected to be negligible. A lower cutoff will lead + to faster models. + """ + num_neighbors_adaptive: Optional[int] = None + """Target number of neighbors for the adaptive cutoff scheme. + + This parameter activates the adaptive cutoff functionality. + Each atomic environments has a different cutoff, that is chosen + such that the number of neighbors is approximately equal to this + value. This can be useful to have a more uniform number of neighbors + per atom, especially in sparse systems. Setting it to None disables + this feature and uses all neighbors within the fixed cutoff radius. + """ + cutoff_function: Literal["Cosine", "Bump"] = "Bump" + """Type of the smoothing function at the cutoff""" + cutoff_width: float = 0.5 + """Width of the smoothing function at the cutoff""" + d_pet: int = 128 + """Dimension of the edge features. + + This hyperparameters controls width of the neural network. In general, + increasing it might lead to better accuracy, especially on larger datasets, at the + cost of increased training and evaluation time. + """ + d_head: int = 128 + """Dimension of the attention heads.""" + d_node: int = 256 + """Dimension of the node features. + + Increasing this hyperparameter might lead to better accuracy, + with a relatively small increase in inference time. + """ + d_feedforward: int = 256 + """Dimension of the feedforward network in the attention layer.""" + num_heads: int = 8 + """Attention heads per attention layer.""" + num_attention_layers: int = 2 + """The number of attention layers in each layer of the graph + neural network. Depending on the dataset, increasing this hyperparameter might + lead to better accuracy, at the cost of increased training and evaluation time. + """ + num_gnn_layers: int = 2 + """The number of graph neural network layers. + + In general, decreasing this hyperparameter to 1 will lead to much faster models, + at the expense of accuracy. Increasing it may or may not lead to better accuracy, + depending on the dataset, at the cost of increased training and evaluation time. 
+ """ + normalization: Literal["RMSNorm", "LayerNorm"] = "RMSNorm" + """Layer normalization type.""" + activation: Literal["SiLU", "SwiGLU"] = "SwiGLU" + """Activation function.""" + attention_temperature: float = 1.0 + """The temperature scaling factor for attention scores.""" + transformer_type: Literal["PreLN", "PostLN"] = "PreLN" + """The order in which the layer normalization and attention + are applied in a transformer block. Available options are ``PreLN`` + (normalization before attention) and ``PostLN`` (normalization after attention).""" + featurizer_type: Literal["residual", "feedforward"] = "feedforward" + """Implementation of the featurizer of the model to use. Available + options are ``residual`` (the original featurizer from the PET paper, that uses + residual connections at each GNN layer for readout) and ``feedforward`` (a modern + version that uses the last representation after all GNN iterations for readout). + Additionally, the feedforward version uses bidirectional features flow during the + message passing iterations, that favors features flowing from atom ``i`` to atom + ``j`` to be not equal to the features flowing from atom ``j`` to atom ``i``.""" + zbl: bool = False + """Use ZBL potential for short-range repulsion""" + long_range: LongRangeHypers = init_with_defaults(LongRangeHypers) + """Long-range Coulomb interactions parameters.""" + + +class TrainerHypers(TypedDict): + """Hyperparameters for training PET models.""" + + distributed: bool = False + """Whether to use distributed training""" + distributed_port: int = 39591 + """Port for distributed communication among processes""" + batch_size: int = 16 + """The number of samples to use in each batch of training. This + hyperparameter controls the tradeoff between training speed and memory usage. In + general, larger batch sizes will lead to faster training, but might require more + memory.""" + num_epochs: int = 1000 + """Number of epochs.""" + warmup_fraction: float = 0.01 + """Fraction of training steps used for learning rate warmup.""" + learning_rate: float = 1e-4 + """Learning rate.""" + weight_decay: Optional[float] = None + + log_interval: int = 1 + """Interval to log metrics.""" + checkpoint_interval: int = 100 + """Interval to save checkpoints.""" + atomic_baseline: FixedCompositionWeights = {} + """The baselines for each target. + + By default, ``metatrain`` will fit a linear model (:class:`CompositionModel + `) to compute the + least squares baseline for each atomic species for each target. + + However, this hyperparameter allows you to provide your own baselines. + The value of the hyperparameter should be a dictionary where the keys are the + target names, and the values are either (1) a single baseline to be used for + all atomic types, or (2) a dictionary mapping atomic types to their baselines. + For example: + + - ``atomic_baseline: {"energy": {1: -0.5, 6: -10.0}}`` will fix the energy + baseline for hydrogen (Z=1) to -0.5 and for carbon (Z=6) to -10.0, while + fitting the baselines for the energy of all other atomic types, as well + as fitting the baselines for all other targets. + - ``atomic_baseline: {"energy": -5.0}`` will fix the energy baseline for + all atomic types to -5.0. + - ``atomic_baseline: {"mtt:dos": 0.0}`` sets the baseline for the "mtt:dos" + target to 0.0, effectively disabling the atomic baseline for that target. 
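+
+ For instance, assuming the trainer options live under the ``training``
+ section of ``options.yaml`` (an illustrative sketch, not a prescribed
+ layout), the first example above could be written as::
+
+     training:
+         atomic_baseline:
+             energy:
+                 1: -0.5
+                 6: -10.0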
+ + This atomic baseline is substracted from the targets during training, which + avoids the main model needing to learn atomic contributions, and likely makes + training easier. When the model is used in evaluation mode, the atomic baseline + is added on top of the model predictions automatically. + + .. note:: + + This atomic baseline is a per-atom contribution. Therefore, if the property + you are predicting is a sum over all atoms (e.g., total energy), the + contribution of the atomic baseline to the total property will be the + atomic baseline multiplied by the number of atoms of that type in the + structure. + """ + scale_targets: bool = True + """Normalize targets to unit std during training.""" + fixed_scaling_weights: FixedScalerWeights = {} + """Weights for target scaling. + + This is passed to the ``fixed_weights`` argument of + :meth:`Scaler.train_model `, + see its documentation to understand exactly what to pass here. + """ + per_structure_targets: list[str] = [] + """Targets to calculate per-structure losses.""" + num_workers: Optional[int] = None + """Number of workers for data loading. If not provided, it is set + automatically.""" + log_mae: bool = True + """Log MAE alongside RMSE""" + log_separate_blocks: bool = False + """Log per-block error.""" + best_model_metric: Literal["rmse_prod", "mae_prod", "loss"] = "mae_prod" + """Metric used to select best checkpoint (e.g., ``rmse_prod``)""" + grad_clip_norm: float = 1.0 + """Maximum gradient norm value.""" + loss: str | dict[str, LossSpecification | str] = "mse" + """This section describes the loss function to be used. See the + :ref:`loss-functions` for more details.""" + batch_atom_bounds: list[Optional[int]] = [None, None] + """Bounds for the number of atoms per batch as [min, max]. Batches with atom + counts outside these bounds will be skipped during training. Use ``None`` for + either value to disable that bound. This is useful for preventing out-of-memory + errors and ensuring consistent computational load. Default: ``[None, None]``.""" + + finetune: NoFinetuneHypers | FinetuneHypers = { + "read_from": None, + "method": "full", + "config": {}, + "inherit_heads": {}, + } + """Parameters for fine-tuning trained PET models. + + See :ref:`label_fine_tuning_concept` for more details. + """ diff --git a/src/metatrain/pet_5/model.py b/src/metatrain/pet_5/model.py new file mode 100644 index 0000000000..25bda7c128 --- /dev/null +++ b/src/metatrain/pet_5/model.py @@ -0,0 +1,1646 @@ +import logging +import typing +import warnings +from math import prod +from typing import Any, Dict, List, Literal, Optional, Tuple + +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.operations._add import _add_block_block +from metatomic.torch import ( + AtomisticModel, + ModelCapabilities, + ModelMetadata, + ModelOutput, + NeighborListOptions, + System, +) +from torch.utils.hooks import RemovableHandle + +from metatrain.utils.abc import ModelInterface +from metatrain.utils.additive import ZBL, CompositionModel +from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.dtype import dtype_to_str +from metatrain.utils.long_range import DummyLongRangeFeaturizer, LongRangeFeaturizer +from metatrain.utils.metadata import merge_metadata +from metatrain.utils.scaler import Scaler +from metatrain.utils.sum_over_atoms import sum_over_atoms + +from . 
import checkpoints +from .documentation import ModelHypers +from .modules.finetuning import apply_finetuning_strategy +from .modules.structures import get_pair_sample_labels, systems_to_batch +from .modules.transformer import CartesianTransformer + + +AVAILABLE_FEATURIZERS = typing.get_args(ModelHypers.__annotations__["featurizer_type"]) + + +class PET(ModelInterface[ModelHypers]): + """ + Metatrain-native implementation of the PET architecture. + + Originally proposed in work (https://arxiv.org/abs/2305.19302v3), + and published in the `pet` package (https://github.com/spozdn/pet). + + :param hypers: Hyperparameters for the PET model. See the documentation for details. + :param dataset_info: Information about the dataset, including atomic types and + targets. + """ + + __checkpoint_version__ = 11 + __supported_devices__ = ["cuda", "cpu"] + __supported_dtypes__ = [torch.float32, torch.float64] + __default_metadata__ = ModelMetadata( + references={"architecture": ["https://arxiv.org/abs/2305.19302v3"]} + ) + component_labels: Dict[str, List[List[Labels]]] + NUM_FEATURE_TYPES: int = 2 # node + edge features + + def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None: + super().__init__(hypers, dataset_info, self.__default_metadata__) + + # Cache frequently accessed hyperparameters + self.cutoff = float(self.hypers["cutoff"]) + self.cutoff_function = self.hypers["cutoff_function"] + self.cutoff_width = float(self.hypers["cutoff_width"]) + self.num_neighbors_adaptive = ( + float(self.hypers["num_neighbors_adaptive"]) + if self.hypers["num_neighbors_adaptive"] is not None + else None + ) + self.d_pet = self.hypers["d_pet"] + self.d_node = self.hypers["d_node"] + self.d_head = self.hypers["d_head"] + self.d_feedforward = self.hypers["d_feedforward"] + self.num_heads = self.hypers["num_heads"] + self.num_gnn_layers = self.hypers["num_gnn_layers"] + self.num_attention_layers = self.hypers["num_attention_layers"] + self.normalization = self.hypers["normalization"] + self.activation = self.hypers["activation"] + self.attention_temperature = self.hypers["attention_temperature"] + self.transformer_type = self.hypers["transformer_type"] + self.featurizer_type = self.hypers["featurizer_type"] + + self.atomic_types = dataset_info.atomic_types + self.requested_nl = NeighborListOptions( + cutoff=self.cutoff, + full_list=True, + strict=True, + ) + num_atomic_species = len(self.atomic_types) + self.gnn_layers = torch.nn.ModuleList( + [ + CartesianTransformer( + self.cutoff, + self.cutoff_width, + self.d_pet, + self.num_heads, + self.d_node, + self.d_feedforward, + self.num_attention_layers, + self.normalization, + self.activation, + self.attention_temperature, + self.transformer_type, + num_atomic_species, + layer_index == 0, # is first layer + ) + for layer_index in range(self.num_gnn_layers) + ] + ) + if self.featurizer_type not in AVAILABLE_FEATURIZERS: + raise ValueError( + f"Unknown featurizer type: {self.featurizer_type}. 
" + f"Available options are: {AVAILABLE_FEATURIZERS}" + ) + if self.featurizer_type == "feedforward": + self.num_readout_layers = 1 + self.combination_norms = torch.nn.ModuleList( + [torch.nn.LayerNorm(2 * self.d_pet) for _ in range(self.num_gnn_layers)] + ) + self.combination_mlps = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(2 * self.d_pet, 2 * self.d_pet), + torch.nn.SiLU(), + torch.nn.Linear(2 * self.d_pet, self.d_pet), + ) + for _ in range(self.num_gnn_layers) + ] + ) + else: + self.num_readout_layers = self.num_gnn_layers + self.combination_norms = torch.nn.ModuleList() + self.combination_mlps = torch.nn.ModuleList() + + self.node_embedders = torch.nn.ModuleList( + [ + torch.nn.Embedding(num_atomic_species, self.d_node) + for _ in range(self.num_readout_layers) + ] + ) + self.edge_embedder = torch.nn.Embedding(num_atomic_species, self.d_pet) + + self.node_heads = torch.nn.ModuleDict() + self.edge_heads = torch.nn.ModuleDict() + self.node_last_layers = torch.nn.ModuleDict() + self.edge_last_layers = torch.nn.ModuleDict() + self.last_layer_feature_size = ( + self.num_readout_layers * self.d_head * self.NUM_FEATURE_TYPES + ) # for LLPR + + # the model is always capable of outputting the internal features + self.outputs = { + "features": ModelOutput(per_atom=True, description="internal features") + } + + self.output_shapes: Dict[str, Dict[str, List[int]]] = {} + self.key_labels: Dict[str, Labels] = {} + self.property_labels: Dict[str, List[Labels]] = {} + self.component_labels: Dict[str, List[List[Labels]]] = {} + self.target_names: List[str] = [] + self.last_layer_parameter_names: Dict[str, List[str]] = {} # for LLPR + for target_name, target_info in dataset_info.targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target_info) + + self.register_buffer( + "species_to_species_index", + torch.full((max(self.atomic_types) + 1,), -1), + ) + for i, species in enumerate(self.atomic_types): + self.species_to_species_index[species] = i + + # long-range module + if self.hypers["long_range"]["enable"]: + self.long_range = True + if not self.hypers["long_range"]["use_ewald"]: + warnings.warn( + "Training PET with the LongRangeFeaturizer initialized " + "with `use_ewald=False` causes instabilities during training. " + "The `use_ewald` variable will be force-switched to `True`. 
" + "during training.", + UserWarning, + stacklevel=2, + ) + self.long_range_featurizer = LongRangeFeaturizer( + hypers=self.hypers["long_range"], + feature_dim=self.d_node, + neighbor_list_options=self.requested_nl, + ) + else: + self.long_range = False + self.long_range_featurizer = DummyLongRangeFeaturizer() # for torchscript + + # additive models: these are handled by the trainer at training + # time, and they are added to the output at evaluation time + composition_model = CompositionModel( + hypers={}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + additive_models = [composition_model] + + # Adds the ZBL repulsion model if requested + if self.hypers["zbl"]: + additive_models.append( + ZBL( + {}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if ZBL.is_valid_target(target_name, target_info) + }, + ), + ) + ) + self.additive_models = torch.nn.ModuleList(additive_models) + + # scaler: this is also handled by the trainer at training time + self.scaler = Scaler(hypers={}, dataset_info=dataset_info) + + self.single_label = Labels.single() + + self.finetune_config: Dict[str, Any] = {} + + def supported_outputs(self) -> Dict[str, ModelOutput]: + return self.outputs + + def restart(self, dataset_info: DatasetInfo) -> "PET": + # merge old and new dataset info + merged_info = self.dataset_info.union(dataset_info) + new_atomic_types = [ + at for at in merged_info.atomic_types if at not in self.atomic_types + ] + new_targets = { + key: value + for key, value in merged_info.targets.items() + if key not in self.dataset_info.targets + } + self.has_new_targets = len(new_targets) > 0 + + if len(new_atomic_types) > 0: + raise ValueError( + f"New atomic types found in the dataset: {new_atomic_types}. " + "The PET model does not support adding new atomic types." + ) + + # register new outputs as new last layers + for target_name, target in new_targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target) + + self.dataset_info = merged_info + + # restart the composition and scaler models + self.additive_models[0] = self.additive_models[0].restart( + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + self.scaler = self.scaler.restart(dataset_info) + + return self + + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return [self.requested_nl] + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + """ + Forward pass of the PET model. + + The forward pass processes atomic systems through multiple stages to produce + predictions for the requested outputs. The computation follows a graph neural + network architecture with attention-based message passing. 
+ + **Stage 0: Input Preparation** + + The input systems are first converted into a batched representation containing: + + - `element_indices_nodes` [n_atoms]: Atomic species of the central atoms + - `element_indices_neighbors` [n_atoms, max_num_neighbors]: Atomic species of + neighboring atoms + - `edge_vectors` [n_atoms, max_num_neighbors, 3]: Cartesian edge vectors + between central atoms and their neighbors + - `padding_mask` [n_atoms, max_num_neighbors]: Mask indicating real vs padded + neighbors + - `reverse_neighbor_index` [n_atoms * max_num_neighbors]: Index of the ji edge + for each ij edge, once the edges are flattened into an array whose first + dimension is n_atoms * max_num_neighbors + - `system_indices` [n_atoms]: System index for each central atom + - `sample_labels` [n_atoms, 2]: Metatensor Labels containing indices of each + atom in each system + + **Stage 1: Feature Computation via GNN Layers** + + Node and edge representations are computed by iterating through the GNN layers + following one of two featurization strategies: + + - **Feedforward featurization**: Features are propagated through all + `num_gnn_layers` GNN layers sequentially, using only the final layer outputs + for readout. At each layer, forward and reversed edge messages are combined + using combination MLPs to enable bidirectional information flow. + + - **Residual featurization**: Intermediate node and edge features from each + GNN layer are saved and used during readout. Edge messages between layers are + averaged to maintain information from all hops. + + During this stage, the model: + + - Embeds atomic species into learned node and edge representations + - Applies Cartesian transformer layers to update features via attention + - Uses reversed neighbor lists to enable bidirectional message passing, where + the new input message from atom `j` to atom `i` in GNN layer N+1 is the + reversed message from atom `i` to atom `j` in GNN layer N + - Applies cutoff functions to weight interactions by distance + + If the long-range module is enabled, electrostatic features computed via Ewald + summation (during training) or Particle-Particle Particle Mesh Ewald (P3M) + (during evaluation) are added to the node features from each GNN layer. + + **Stage 2: Intermediate Feature Output (Optional)** + + If "features" is requested in the outputs, node and edge features from all + layers are concatenated to produce intermediate representations. Edge features + are summed over neighbors with cutoff weighting to obtain per-node + contributions. This output can be used for transfer learning or analysis. + + **Stage 3: Last Layer Feature Computation** + + For each requested output, output-specific heads (shallow MLPs with two linear + layers and SiLU activations) are applied to both node and edge features from + each GNN layer. This produces last layer features that are specialized for each + prediction target. These features can be optionally returned as auxiliary + outputs (e.g., "mtt::aux::energy_last_layer_features") for analysis or + transfer learning. 
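+
+        For instance (an illustrative request, assuming a target named
+        ``energy`` is registered on the model), passing
+        ``{"energy": ModelOutput(), "mtt::aux::energy_last_layer_features":
+        ModelOutput(per_atom=True)}`` as ``outputs`` returns the energy
+        predictions together with the per-atom last layer features used to
+        produce them.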
+ + **Stage 4: Atomic Predictions** + + Final linear layers are applied to the last layer features to produce per-atom + predictions for each requested output: + + - Node and edge last layer features are processed through separate linear + layers for each output block + - Contributions from all GNN layers are summed + - Edge contributions are summed over neighbors with cutoff weighting + - For rank-2 Cartesian tensors (e.g., stress), predictions are symmetrized and + normalized by cell volume + - Multiple tensor blocks per output are handled independently + + **Post-processing (Evaluation Only)** + + During evaluation (not training), the following transformations are applied: + + 1. **Scaling**: Predictions are scaled using learned or configured scale + factors + 2. **Additive contributions**: Composition model and optional ZBL repulsion + contributions are added to the predictions + + :param systems: List of `metatomic.torch.System` objects to process. Each + system should contain atomic positions, species, and cell information, with + neighbor lists computed according to `requested_neighbor_lists()`. + :param outputs: Dictionary of requested outputs in the format + {output_name: ModelOutput(...)}. The model supports: + + - Target properties (energy, forces, stress, etc.) + - "features": intermediate representations from Stage 2 + - Auxiliary last layer features (e.g., + "mtt::aux::energy_last_layer_features") + + :param selected_atoms: Optional `metatensor.torch.Labels` object specifying a + subset of atoms for which to compute outputs. If `None`, all atoms are + included. This is useful for computing properties for specific atomic + environments. + :return: Dictionary of `metatensor.torch.TensorMap` objects containing the + requested outputs. Each TensorMap contains per-atom or per-structure + predictions (depending on the ModelOutput configuration) with appropriate + metatensor metadata (samples, components, properties). 
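+
+        Example (an illustrative sketch, assuming the model was trained on a
+        target named ``energy`` and that ``systems`` already carry the neighbor
+        lists returned by ``requested_neighbor_lists()``)::
+
+            outputs = {"energy": ModelOutput(per_atom=False)}
+            predictions = model(systems, outputs)
+            energy = predictions["energy"]  # TensorMap with one row per system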
+ """ + device = systems[0].device + return_dict: Dict[str, TensorMap] = {} + nl_options = self.requested_neighbor_lists()[0] + + if self.single_label.values.device != device: + self._move_labels_to_device(device) + + # **Stage 0: Input Preparation** + ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + padding_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + sample_labels, + centers, + nef_to_edges_neighbor, + ) = systems_to_batch( + systems, + nl_options, + self.atomic_types, + self.species_to_species_index, + self.cutoff_function, + self.cutoff_width, + self.num_neighbors_adaptive, + selected_atoms, + ) + + pair_sample_labels = get_pair_sample_labels( + systems, sample_labels, nl_options, device + ) + + # Optional diagnostic token capture: register temporary module hooks + diagnostic_handles = torch.jit.annotate(List[Any], []) + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + diagnostic_handles = self._prepare_diagnostic_handles( + outputs, + return_dict, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + + # the scaled_dot_product_attention function from torch cannot do + # double backward, so we will use manual attention if needed + use_manual_attention = edge_vectors.requires_grad and self.training + + with torch.profiler.record_function("PET::_calculate_features"): + # **Stage 1: Feature Computation via GNN Layers** + featurizer_inputs: Dict[str, torch.Tensor] = dict( + element_indices_nodes=element_indices_nodes, + element_indices_neighbors=element_indices_neighbors, + edge_vectors=edge_vectors, + edge_distances=edge_distances, + reverse_neighbor_index=reverse_neighbor_index, + padding_mask=padding_mask, + cutoff_factors=cutoff_factors, + ) + for featurizer_input_name, tensor in featurizer_inputs.items(): + if "mtt::features::" + featurizer_input_name in outputs: + return_dict["mtt::features::" + featurizer_input_name] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + node_features_list, edge_features_list = self._calculate_features( + featurizer_inputs, + use_manual_attention=use_manual_attention, + ) + + # If the long-range module is activated, we add the long-range features + # on top of the node features + + if self.long_range: + long_range_features = self._calculate_long_range_features( + systems, node_features_list, edge_distances, padding_mask + ) + for i in range(self.num_readout_layers): + node_features_list[i] = ( + node_features_list[i] + long_range_features + ) * 0.5**0.5 + + # **Stage 2: Intermediate Feature Output (Optional)** + with torch.profiler.record_function("PET::_get_output_features"): + if "features" in outputs: + features_dict = self._get_output_features( + node_features_list, + edge_features_list, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + # Since return_dict.update(features_dict) is not Torch-Scriptable, + # we use a simple iteration over the features_dict items. 
+ for k, v in features_dict.items(): + return_dict[k] = v + + # **Stage 3: Last Layer Feature Computation** + with torch.profiler.record_function("PET::_calculate_last_layer_features"): + node_last_layer_features_dict, edge_last_layer_features_dict = ( + self._calculate_last_layer_features( + node_features_list, + edge_features_list, + ) + ) + last_layer_features_dict = self._get_output_last_layer_features( + node_last_layer_features_dict, + edge_last_layer_features_dict, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + + for k, v in last_layer_features_dict.items(): + return_dict[k] = v + + # **Stage 4: Atomic Predictions** + with torch.profiler.record_function("PET::_calculate_atomic_predictions"): + node_atomic_predictions_dict, edge_atomic_predictions_dict = ( + self._calculate_atomic_predictions( + node_last_layer_features_dict, + edge_last_layer_features_dict, + padding_mask, + cutoff_factors, + outputs, + ) + ) + atomic_predictions_dict = self._get_output_atomic_predictions( + systems, + node_atomic_predictions_dict, + edge_atomic_predictions_dict, + edge_vectors, + system_indices, + sample_labels, + outputs, + selected_atoms, + ) + + for k, v in atomic_predictions_dict.items(): + return_dict[k] = v + + # **Post-processing (Evaluation Only)** + + with torch.profiler.record_function("PET::post-processing"): + if not self.training: + # at evaluation, we also introduce the scaler and additive contributions + return_dict = self.scaler( + systems, return_dict, selected_atoms=selected_atoms + ) + for additive_model in self.additive_models: + outputs_for_additive_model: Dict[str, ModelOutput] = {} + for name, output in outputs.items(): + if name in additive_model.outputs: + outputs_for_additive_model[name] = output + additive_contributions = additive_model( + systems, + outputs_for_additive_model, + selected_atoms, + ) + for name in additive_contributions: + # TODO: uncomment this after metatensor.torch.add + # is updated to handle sparse sums + # return_dict[name] = metatensor.torch.add( + # return_dict[name], + # additive_contributions[name].to( + # device=return_dict[name].device, + # dtype=return_dict[name].dtype + # ), + # ) + # TODO: "manual" sparse sum: update to metatensor.torch.add + # after sparse sum is implemented in metatensor.operations + output_blocks: List[TensorBlock] = [] + for k, b in return_dict[name].items(): + if k in additive_contributions[name].keys: + output_blocks.append( + _add_block_block( + b, + additive_contributions[name] + .block(k) + .to(device=b.device, dtype=b.dtype), + ) + ) + else: + output_blocks.append(b) + return_dict[name] = TensorMap( + return_dict[name].keys, output_blocks + ) + + # remove any diagnostic hooks we registered and attach tokens + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + for h in diagnostic_handles: + h.remove() + + return return_dict + + def _create_diagnostic_feature_tensormap( + self, + tensor: torch.Tensor, + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> TensorMap: + assert tensor.shape[0] == sample_labels.values.shape[0], ( + "diagnostic feature tensor must be per-atom or per-pair like in shape." 
+ f" Got tensor.shape = {tensor.shape}, " + ) + + device = tensor.device + + outp = tensor.detach().clone() + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + if outp.shape[1] == 1: # node-like, shape (n_atoms, 1, d) + outp = outp.squeeze(1) + labels = sample_labels + + else: # edge-like, shape (n_atoms, num_neighbors, d) + outp = outp[centers, nef_to_edges_neighbor] + labels = pair_sample_labels + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + return TensorMap( + Labels(["_"], torch.tensor([[0]])).to(device=device), + [ + TensorBlock( + values=outp, + samples=labels.to(device=device), + components=[], + properties=Labels( + ["_"], + torch.arange(outp.shape[1]).reshape(-1, 1), + ).to(device=device), + ) + ], + ) + + @torch.jit.ignore + def _prepare_diagnostic_handles( + self, + outputs: Dict[str, ModelOutput], + return_dict: Dict[str, Any], + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> List[Any]: + """ + Prepare forward hooks to capture diagnostic tokens from internal modules. + + :param outputs: Dictionary of requested outputs. + :param return_dict: Dictionary to store captured tokens. + :param centers: Tensor mapping center atoms to their indices. + :param nef_to_edges_neighbor: Tensor mapping neighbor edges to their indices. + :param sample_labels: Labels for individual atoms. + :param pair_sample_labels: Labels for atom pairs. + + :return: List of removable handles for the registered hooks. + """ + diagnostic_handles = [] + + def _resolve_module(path: str) -> Any: + obj: Any = self + for part in path.split("."): + if part.isdigit(): + obj = obj[int(part)] + else: + if not hasattr(obj, part): + raise AttributeError( + f"Module path '{path}' not found at '{part}'" + ) + obj = getattr(obj, part) + return obj + + # Build list of possible module paths that can be captured for these model + # hypers. 
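+        # For example (an illustrative request, valid only if the corresponding
+        # submodule exists for the current hypers): asking for the output name
+        # "mtt::features::gnn_layers.0.trans.layers.0.attention" registers a
+        # forward hook on that attention module and stores its output in
+        # `return_dict` as a diagnostic TensorMap.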
+ possible_capture_paths: List[str] = [] + + for i in range(self.num_readout_layers): + possible_capture_paths.append(f"node_embedders.{i}") + possible_capture_paths.append("edge_embedder") + + for i in range(self.num_gnn_layers): + # Total GNN layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}_node") + possible_capture_paths.append(f"gnn_layers.{i}_edge") + + # Finer-grained: embeddings and compressions + possible_capture_paths.append(f"gnn_layers.{i}.edge_embedder") + if i > 0: + possible_capture_paths.append(f"gnn_layers.{i}.neighbor_embedder") + possible_capture_paths.append(f"gnn_layers.{i}.compress") + + # transformer layers + for j in range(self.num_attention_layers): + # Transformer layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_node") + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_edge") + + # Finer-grained: attention and MLP submodules + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_contraction" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_mlp" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_mlp" + ) + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}.mlp") + + # TODO: this is subtle, depending on whether the featurizer is feedforward or + # residual. To fix later. + # possible_capture_paths.append(f"combination_norms") + # possible_capture_paths.append(f"combination_mlps") + + for path in possible_capture_paths: + if "mtt::features::" + path not in outputs: + continue + + if "_node" in path: + suffix = "_node" + path = path.replace(suffix, "") + elif "_edge" in path: + suffix = "_edge" + path = path.replace(suffix, "") + else: + suffix = "" + + module = _resolve_module(path) + + def make_hook(p: str, suffix: str) -> Any: + def _hook( + module: torch.nn.Module, inp: Any, outp: Any + ) -> None: + if isinstance(outp, tuple): + assert "_node" in suffix or "_edge" in suffix, ( + "When capturing from a module that returns multiple " + "outputs, the requested output must carry the suffix " + "'_node' or '_edge'." + ) + if suffix == "_node": + tensor = outp[0] + else: + assert suffix == "_edge" + tensor = outp[1] + + return_dict[f"mtt::features::{p}{suffix}"] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + else: + return_dict["mtt::features::" + p] = ( + self._create_diagnostic_feature_tensormap( + outp, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + + return _hook + + handle = module.register_forward_hook(make_hook(path, suffix)) + diagnostic_handles.append(handle) + + return diagnostic_handles + + def _calculate_features( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Calculate node and edge features using the selected featurization strategy. + Returns lists of feature tensors from GNN layers. 
+ + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors + - List of edge feature tensors + In the case of feedforward featurization, each list contains a single tensor + from the final GNN layer. In the case of residual featurization, each list + contains tensors from all GNN layers. + """ + if self.featurizer_type == "feedforward": + return self._feedforward_featurization_impl(inputs, use_manual_attention) + else: + return self._residual_featurization_impl(inputs, use_manual_attention) + + def _feedforward_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Feedforward featurization: iterates features through all GNN layers, + returning only the final layer outputs. Uses combination MLPs to mix + forward and reversed edge messages at each layer. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from the final GNN layer + - List of edge feature tensors from the final GNN layer + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + + input_node_embeddings = self.node_embedders[0](inputs["element_indices_nodes"]) + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for combination_norm, combination_mlp, gnn_layer in zip( + self.combination_norms, self.combination_mlps, self.gnn_layers, strict=True + ): + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + input_node_embeddings = output_node_embeddings + new_input_edge_embeddings = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + # input_messages = 0.5 * (output_edge_embeddings + new_input_messages) + concatenated = torch.cat( + [output_edge_embeddings, new_input_edge_embeddings], dim=-1 + ) + input_edge_embeddings = ( + input_edge_embeddings + + output_edge_embeddings + + combination_mlp(combination_norm(concatenated)) + ) + + node_features_list.append(input_node_embeddings) + edge_features_list.append(input_edge_embeddings) + return node_features_list, edge_features_list + + def _residual_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Residual featurization: saves intermediate features from each GNN layer + for use in readout. 
Averages forward and reversed edge messages between layers. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from all GNN layers + - List of edge feature tensors from all GNN layers + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for node_embedder, gnn_layer in zip( + self.node_embedders, self.gnn_layers, strict=True + ): + input_node_embeddings = node_embedder(inputs["element_indices_nodes"]) + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + node_features_list.append(output_node_embeddings) + edge_features_list.append(output_edge_embeddings) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + # (Flatten, index, and reshape to the original shape) + new_input_messages = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + input_edge_embeddings = 0.5 * (input_edge_embeddings + new_input_messages) + return node_features_list, edge_features_list + + def _calculate_long_range_features( + self, + systems: List[System], + node_features_list: List[torch.Tensor], + edge_distances: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + """ + Calculate long-range electrostatic features using Ewald summation. + Forces use_ewald=True during training for stability. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_distances: Tensor of edge distances [n_atoms, max_num_neighbors]. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :return: Tensor of long-range features [n_atoms, d_pet]. + """ + if self.training: + # Currently, the long-range implementation show instabilities + # during training if P3MCalculator is used instead of the + # EwaldCalculator. We will use the EwaldCalculator for training. 
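+            # (the flag is force-switched here, at run time, whenever the model
+            # is in training mode; see also the corresponding warning emitted in
+            # __init__ when `use_ewald=False` is requested in the hypers)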
+ self.long_range_featurizer.use_ewald = True + flattened_lengths = edge_distances[padding_mask] + short_range_features = ( + torch.stack(node_features_list).sum(dim=0) + * (1 / len(node_features_list)) ** 0.5 + ) + long_range_features = self.long_range_featurizer( + systems, short_range_features, flattened_lengths + ) + return long_range_features + + def _get_output_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Concatenate node and edge features from all layers into intermediate + feature representations. Edge features are summed with cutoff weighting. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping "features" to a TensorMap of intermediate + representations, either per-atom or summed over atoms. + """ + features_dict: Dict[str, TensorMap] = {} + node_features = torch.cat(node_features_list, dim=1) + edge_features = torch.cat(edge_features_list, dim=2) + edge_features = (edge_features * cutoff_factors[:, :, None]).sum(dim=1) + features = torch.cat([node_features, edge_features], dim=1) + + feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=features, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + features.shape[-1], device=features.device + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + feature_tmap = mts.slice( + feature_tmap, + axis="samples", + selection=selected_atoms, + ) + if requested_outputs["features"].per_atom: + features_dict["features"] = feature_tmap + else: + features_dict["features"] = sum_over_atoms(feature_tmap) + return features_dict + + def _calculate_last_layer_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + ) -> Tuple[Dict[str, List[torch.Tensor]], Dict[str, List[torch.Tensor]]]: + """ + Apply output-specific heads to node and edge features from each GNN layer. + Returns dictionaries mapping output names to lists of head-transformed features. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of node last layer features + - Dictionary mapping output names to lists of edge last layer features + """ + node_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + + # Calculating node last layer features + for output_name, node_heads in self.node_heads.items(): + if output_name not in node_last_layer_features_dict: + node_last_layer_features_dict[output_name] = [] + for i, node_head in enumerate(node_heads): + node_last_layer_features_dict[output_name].append( + node_head(node_features_list[i]) + ) + + # Calculating edge last layer features + for output_name, edge_heads in self.edge_heads.items(): + if output_name not in edge_last_layer_features_dict: + edge_last_layer_features_dict[output_name] = [] + for i, edge_head in enumerate(edge_heads): + edge_last_layer_features_dict[output_name].append( + edge_head(edge_features_list[i]) + ) + + return node_last_layer_features_dict, edge_last_layer_features_dict + + def _get_output_last_layer_features( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge last layer features for requested last layer + features output. Edge features are summed with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping requested last layer features output names + to TensorMaps of last layer features, either per-atom or summed over atoms. 
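+
+        For example (illustrative), requesting
+        ``"mtt::aux::energy_last_layer_features"`` returns, for an ``energy``
+        target, the node and edge last layer features of every readout layer
+        concatenated along the feature dimension.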
+ """ + last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + last_layer_features_outputs: Dict[str, TensorMap] = {} + for output_name in node_last_layer_features_dict.keys(): + if not should_compute_last_layer_features(output_name, requested_outputs): + continue + if output_name not in last_layer_features_dict: + last_layer_features_dict[output_name] = [] + for i in range(len(node_last_layer_features_dict[output_name])): + node_last_layer_features = node_last_layer_features_dict[output_name][i] + edge_last_layer_features = edge_last_layer_features_dict[output_name][i] + edge_last_layer_features = ( + edge_last_layer_features * cutoff_factors[:, :, None] + ).sum(dim=1) + last_layer_features_dict[output_name].append(node_last_layer_features) + last_layer_features_dict[output_name].append(edge_last_layer_features) + + for output_name in requested_outputs: + if not ( + output_name.startswith("mtt::aux::") + and output_name.endswith("_last_layer_features") + ): + continue + base_name = output_name.replace("mtt::aux::", "").replace( + "_last_layer_features", "" + ) + # the corresponding output could be base_name or mtt::base_name + if f"mtt::{base_name}" in last_layer_features_dict: + base_name = f"mtt::{base_name}" + last_layer_features_values = torch.cat( + last_layer_features_dict[base_name], dim=1 + ) + last_layer_feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=last_layer_features_values, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + last_layer_features_values.shape[-1], + device=last_layer_features_values.device, + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + last_layer_feature_tmap = mts.slice( + last_layer_feature_tmap, + axis="samples", + selection=selected_atoms, + ) + last_layer_features_options = requested_outputs[output_name] + if last_layer_features_options.per_atom: + last_layer_features_outputs[output_name] = last_layer_feature_tmap + else: + last_layer_features_outputs[output_name] = sum_over_atoms( + last_layer_feature_tmap + ) + return last_layer_features_outputs + + def _calculate_atomic_predictions( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + padding_mask: torch.Tensor, + cutoff_factors: torch.Tensor, + outputs: Dict[str, ModelOutput], + ) -> Tuple[ + Dict[str, List[List[torch.Tensor]]], Dict[str, List[List[torch.Tensor]]] + ]: + """ + Apply final linear layers to last layer features to produce + per-atom predictions. Handles multiple blocks per output and sums + edge contributions with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param outputs: Dictionary of requested outputs. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of lists of node atomic + prediction tensors (one list per GNN layer, one tensor per block) + - Dictionary mapping output names to lists of lists of edge atomic + prediction tensors (one list per GNN layer, one tensor per block) + """ + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + + # Computing node atomic predictions. Since we have last layer features + # for each GNN layer, and each last layer can have multiple blocks, + # we apply each last layer block to each of the last layer features. + + for output_name, node_last_layers in self.node_last_layers.items(): + if output_name in outputs: + node_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, node_last_layer in enumerate(node_last_layers): + node_last_layer_features = node_last_layer_features_dict[ + output_name + ][i] + node_atomic_predictions_by_block: List[torch.Tensor] = [] + for node_last_layer_by_block in node_last_layer.values(): + node_atomic_predictions_by_block.append( + node_last_layer_by_block(node_last_layer_features) + ) + node_atomic_predictions_dict[output_name].append( + node_atomic_predictions_by_block + ) + + # Computing edge atomic predictions. Following the same logic as above, + # we (1) iterate over the last layer features and last layer blocks, and (2) + # sum the edge features with cutoff factors to get their per-node contribution. + + for output_name, edge_last_layers in self.edge_last_layers.items(): + if output_name in outputs: + edge_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, edge_last_layer in enumerate(edge_last_layers): + edge_last_layer_features = edge_last_layer_features_dict[ + output_name + ][i] + edge_atomic_predictions_by_block: List[torch.Tensor] = [] + for edge_last_layer_by_block in edge_last_layer.values(): + edge_atomic_predictions = edge_last_layer_by_block( + edge_last_layer_features + ) + expanded_padding_mask = padding_mask[..., None].repeat( + 1, 1, edge_atomic_predictions.shape[2] + ) + edge_atomic_predictions = torch.where( + ~expanded_padding_mask, 0.0, edge_atomic_predictions + ) + edge_atomic_predictions_by_block.append( + (edge_atomic_predictions * cutoff_factors[:, :, None]).sum( + dim=1 + ) + ) + edge_atomic_predictions_dict[output_name].append( + edge_atomic_predictions_by_block + ) + + return node_atomic_predictions_dict, edge_atomic_predictions_dict + + def _get_output_atomic_predictions( + self, + systems: List[System], + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_vectors: torch.Tensor, + system_indices: torch.Tensor, + sample_labels: Labels, + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge atomic predictions into final TensorMaps. + Handles rank-2 Cartesian tensors by symmetrizing them. + Returns per-atom or per-structure predictions based on output configuration. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of node atomic prediction tensors (one list per GNN layer, + one tensor per block). 
+ :param edge_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of edge atomic prediction tensors (one list per GNN layer, + one tensor per block). + :param edge_vectors: Tensor of edge vectors [n_atoms, max_num_neighbors, 3]. + :param system_indices: Tensor mapping each atom to its system index + [n_atoms]. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param outputs: Dictionary of requested outputs. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :return: Dictionary mapping requested output names to TensorMaps of + predictions, either per-atom or summed over atoms. + """ + atomic_predictions_tmap_dict: Dict[str, TensorMap] = {} + for output_name in self.target_names: + if output_name in outputs: + atomic_predictions_by_block = { + key: torch.zeros( + 1, dtype=edge_vectors.dtype, device=edge_vectors.device + ) + for key in self.output_shapes[output_name].keys() + } + + node_atomic_predictions_by_block = node_atomic_predictions_dict[ + output_name + ] + edge_atomic_predictions_by_block = edge_atomic_predictions_dict[ + output_name + ] + for i in range(len(node_atomic_predictions_by_block)): + node_atomic_prediction_block = node_atomic_predictions_by_block[i] + edge_atomic_prediction_block = edge_atomic_predictions_by_block[i] + for j, key in enumerate(atomic_predictions_by_block): + node_atomic_predictions = node_atomic_prediction_block[j] + edge_atomic_predictions = edge_atomic_prediction_block[j] + atomic_predictions_by_block[key] = atomic_predictions_by_block[ + key + ] + (node_atomic_predictions + edge_atomic_predictions) + + if output_name == "non_conservative_stress": # TODO: variants + block_key = list(atomic_predictions_by_block.keys())[0] + output_shapes_values = list( + self.output_shapes[output_name].values() + ) + num_properties = output_shapes_values[0][-1] + symmetrized = process_non_conservative_stress( + atomic_predictions_by_block[block_key], + systems, + system_indices, + num_properties, + ) + atomic_predictions_by_block[block_key] = symmetrized + + blocks = [ + TensorBlock( + values=atomic_predictions_by_block[key].reshape([-1] + shape), + samples=sample_labels, + components=components, + properties=properties, + ) + for key, shape, components, properties in zip( + self.output_shapes[output_name].keys(), + self.output_shapes[output_name].values(), + self.component_labels[output_name], + self.property_labels[output_name], + strict=True, + ) + ] + atomic_predictions_tmap_dict[output_name] = TensorMap( + keys=self.key_labels[output_name], + blocks=blocks, + ) + # If selected atoms request is provided, we slice the atomic predictions + # tensor maps to get the predictions for the selected atoms only. + + if selected_atoms is not None: + for output_name, tmap in atomic_predictions_tmap_dict.items(): + atomic_predictions_tmap_dict[output_name] = mts.slice( + tmap, axis="samples", selection=selected_atoms + ) + + # If per-atom predictions are requested, we return the atomic predictions + # tensor maps. Otherwise, we sum the atomic predictions over the atoms + # to get the final per-structure predictions for each requested output. 
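+        # (e.g. an "energy" output requested with per_atom=False is reduced from
+        # one row per atom to one row per system by `sum_over_atoms`)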
+ + for output_name, atomic_property in atomic_predictions_tmap_dict.items(): + if outputs[output_name].per_atom: + atomic_predictions_tmap_dict[output_name] = atomic_property + else: + atomic_predictions_tmap_dict[output_name] = sum_over_atoms( + atomic_property + ) + + return atomic_predictions_tmap_dict + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + context: Literal["restart", "finetune", "export"], + ) -> "PET": + if context == "restart": + logging.info(f"Using latest model from epoch {checkpoint['epoch']}") + model_state_dict = checkpoint["model_state_dict"] + elif context in {"finetune", "export"}: + logging.info(f"Using best model from epoch {checkpoint['best_epoch']}") + model_state_dict = checkpoint["best_model_state_dict"] + else: + raise ValueError("Unknown context tag for checkpoint loading!") + + # Create the model + model_data = checkpoint["model_data"] + model = cls( + hypers=model_data["model_hypers"], + dataset_info=model_data["dataset_info"], + ) + + finetune_config = model_state_dict.pop("finetune_config", {}) + if finetune_config: + # Apply the finetuning strategy + model = apply_finetuning_strategy(model, finetune_config) + state_dict_iter = iter(model_state_dict.values()) + next(state_dict_iter) # skip the species_to_species_index + dtype = next(state_dict_iter).dtype + model.to(dtype).load_state_dict(model_state_dict) + model.additive_models[0].sync_tensor_maps() + model.scaler.sync_tensor_maps() + + # Loading the metadata from the checkpoint + model.metadata = merge_metadata(model.metadata, checkpoint.get("metadata")) + + return model + + def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: + dtype = next(self.parameters()).dtype + if dtype not in self.__supported_dtypes__: + raise ValueError(f"unsupported dtype {dtype} for PET") + + # Make sure the model is all in the same dtype + # For example, after training, the additive models could still be in + # float64 + self.to(dtype) + + # Additionally, the composition model contains some `TensorMap`s that cannot + # be registered correctly with Pytorch. This function moves them: + self.additive_models[0].weights_to(torch.device("cpu"), torch.float64) + + interaction_ranges = [self.num_gnn_layers * self.cutoff] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + + capabilities = ModelCapabilities( + outputs=self.outputs, + atomic_types=self.atomic_types, + interaction_range=interaction_range, + length_unit=self.dataset_info.length_unit, + supported_devices=self.__supported_devices__, + dtype=dtype_to_str(dtype), + ) + + metadata = merge_metadata(self.metadata, metadata) + + return AtomisticModel(self.eval(), metadata, capabilities) + + def _add_output(self, target_name: str, target_info: TargetInfo) -> None: + """ + Register a new output target by creating corresponding heads and last layers. + Sets up node/edge heads and linear layers for all readout layers. + + :param target_name: Name of the target to add. + :param target_info: TargetInfo object containing details about the target. + """ + # one output shape for each tensor block, grouped by target (i.e. 
tensormap) + self.output_shapes[target_name] = {} + for key, block in target_info.layout.items(): + dict_key = target_name + for n, k in zip(key.names, key.values, strict=True): + dict_key += f"_{n}_{int(k)}" + self.output_shapes[target_name][dict_key] = [ + len(comp.values) for comp in block.components + ] + [len(block.properties.values)] + + self.outputs[target_name] = ModelOutput( + quantity=target_info.quantity, + unit=target_info.unit, + per_atom=True, + description=target_info.description, + ) + + self.node_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_node, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_pet, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.node_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: torch.nn.Linear( + self.d_head, + prod(shape), + bias=True, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: torch.nn.Linear( + self.d_head, + prod(shape), + bias=True, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + # Register last-layer parameters, in the same order as they are returned as + # last-layer features in the model + self.last_layer_parameter_names[target_name] = [] + for layer_index in range(self.num_readout_layers): + for key in self.output_shapes[target_name].keys(): + self.last_layer_parameter_names[target_name].append( + f"node_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + self.last_layer_parameter_names[target_name].append( + f"edge_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + + ll_features_name = get_last_layer_features_name(target_name) + self.outputs[ll_features_name] = ModelOutput( + per_atom=True, description=f"last layer features for {target_name}" + ) + self.key_labels[target_name] = target_info.layout.keys + self.component_labels[target_name] = [ + block.components for block in target_info.layout.blocks() + ] + self.property_labels[target_name] = [ + block.properties for block in target_info.layout.blocks() + ] + + def _move_labels_to_device(self, device: torch.device) -> None: + self.single_label = self.single_label.to(device) + self.key_labels = { + output_name: label.to(device) + for output_name, label in self.key_labels.items() + } + self.component_labels = { + output_name: [ + [labels.to(device) for labels in components_block] + for components_block in components_tmap + ] + for output_name, components_tmap in self.component_labels.items() + } + self.property_labels = { + output_name: [labels.to(device) for labels in properties_tmap] + for output_name, properties_tmap in self.property_labels.items() + } + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["model_ckpt_version"] == v: + update = getattr(checkpoints, f"model_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["model_ckpt_version"] = v + 1 + + if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__: + raise 
RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using model " + f"version {checkpoint['model_ckpt_version']}, while the current model " + f"version is {cls.__checkpoint_version__}." + ) + + return checkpoint + + def get_checkpoint(self) -> Dict: + model_state_dict = self.state_dict() + model_state_dict["finetune_config"] = self.finetune_config + checkpoint = { + "architecture_name": "pet_5", + "model_ckpt_version": self.__checkpoint_version__, + "metadata": self.metadata, + "model_data": { + "model_hypers": self.hypers, + "dataset_info": self.dataset_info, + }, + "epoch": None, + "best_epoch": None, + "model_state_dict": model_state_dict, + "best_model_state_dict": self.state_dict(), + } + return checkpoint + + +def process_non_conservative_stress( + tensor: torch.Tensor, + systems: List[System], + system_indices: torch.Tensor, + num_properties: int, +) -> torch.Tensor: + """ + Symmetrizes and normalizes by the volume rank-2 Cartesian tensors that are meant + to predict the non-conservative stress. + + :param tensor: Tensor of shape [n_atoms, 9 * num_properties]. + :param systems: List of `metatomic.torch.System` objects to process. + :param system_indices: Tensor mapping each atom to its system index [n_atoms]. + :param num_properties: Number of properties in the tensor (e.g., 6 for stress). + :return: Symmetrized tensor of shape [n_atoms, 3, 3, num_properties], divided by the + cell volume. + """ + # Reshape to 3x3 matrix per atom + tensor_as_three_by_three = tensor.reshape(-1, 3, 3, num_properties) + + # Normalize by cell volume + volumes = torch.stack([torch.abs(torch.det(system.cell)) for system in systems]) + # Zero volume can happen due to metatomic's convention of zero cell + # vectors for non-periodic directions. The actual volume is +inf + volumes[volumes == 0.0] = torch.inf + volumes_by_atom = volumes[system_indices].unsqueeze(1).unsqueeze(2).unsqueeze(3) + tensor_as_three_by_three = tensor_as_three_by_three / volumes_by_atom + + # Symmetrize + tensor_as_three_by_three = ( + tensor_as_three_by_three + tensor_as_three_by_three.transpose(1, 2) + ) / 2.0 + + return tensor_as_three_by_three + + +def get_last_layer_features_name(target_name: str) -> str: + """ + Get the auxiliary output name for last layer features of a target. + + :param target_name: Name of the target. + :return: Name of the corresponding last layer features output. + """ + base_name = target_name.replace("mtt::", "") + return f"mtt::aux::{base_name}_last_layer_features" + + +def should_compute_last_layer_features( + output_name: str, requested_outputs: Dict[str, ModelOutput] +) -> bool: + """ + Check if last layer features should be computed for an output. + + :param output_name: Name of the output to check. + :param requested_outputs: Dictionary of requested outputs. + :return: True if last layer features should be computed, False otherwise. + """ + if output_name in requested_outputs: + return True + ll_features_name = get_last_layer_features_name( + output_name.replace("mtt::aux::", "") + ) + return ll_features_name in requested_outputs diff --git a/src/metatrain/pet_5/modules/adaptive_cutoff.py b/src/metatrain/pet_5/modules/adaptive_cutoff.py new file mode 100644 index 0000000000..5f6206bdb1 --- /dev/null +++ b/src/metatrain/pet_5/modules/adaptive_cutoff.py @@ -0,0 +1,166 @@ +from typing import Optional + +import torch + +from .utilities import cutoff_func_bump as cutoff_func + + +# minimum value for the probe cutoff. this avoids getting too close +# to the central atom. 
in practice it could be also set to a larger value +DEFAULT_MIN_PROBE_CUTOFF = 0.5 +# recommended smooth cutoff width for effective neighbor number calculation +# smaller values lead to a more "step-like" behavior, but can be +# numerically unstable. in practice this will be called with the +# same cutoff as the main cutoff function +DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH = 1.0 + + +def get_adaptive_cutoffs( + centers: torch.Tensor, + edge_distances: torch.Tensor, + num_neighbors_adaptive: float, + num_nodes: int, + max_cutoff: float, + min_cutoff: float = DEFAULT_MIN_PROBE_CUTOFF, + cutoff_width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, + probe_spacing: Optional[float] = None, + weight_width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the adaptive cutoff values for each center atom. + + :param centers: Indices of the center atoms. + :param edge_distances: Distances between centers and their neighbors. + :param num_neighbors_adaptive: Target number of neighbors per atom. + :param num_nodes: Total number of center atoms. + :param max_cutoff: Maximum cutoff distance to consider. + :param min_cutoff: Minimum cutoff distance to consider. + :param cutoff_width: Width of the smooth cutoff taper region. + :param probe_spacing: Spacing between probe cutoffs. If None, it will be + automatically determined from the cutoff width. + :param weight_width: Width of the cutoff selection weight function. If None, it + will be automatically determined from the empirical neighbor counts. + :return: Adapted cutoff distances for each center atom. + """ + + # heuristic for the grid spacing of probe cutoffs, based on a + # the smoothness of the cutoff function + if probe_spacing is None: + probe_spacing = cutoff_width / 4.0 + probe_cutoffs = torch.arange( + min_cutoff, + max_cutoff, + probe_spacing, + device=edge_distances.device, + dtype=edge_distances.dtype, + ) + with torch.profiler.record_function("PET::get_effective_num_neighbors"): + effective_num_neighbors = get_effective_num_neighbors( + edge_distances, + probe_cutoffs, + centers, + num_nodes, + width=cutoff_width, + ) + + with torch.profiler.record_function("PET::get_cutoff_weights"): + cutoffs_weights = get_gaussian_cutoff_weights( + effective_num_neighbors, + num_neighbors_adaptive, + width=weight_width, + ) + with torch.profiler.record_function("PET::calculate_adapted_cutoffs"): + adapted_atomic_cutoffs = probe_cutoffs @ cutoffs_weights.T + return adapted_atomic_cutoffs + + +def get_effective_num_neighbors( + edge_distances: torch.Tensor, + probe_cutoffs: torch.Tensor, + centers: torch.Tensor, + num_nodes: int, + width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, +) -> torch.Tensor: + """ + Computes the effective number of neighbors for each probe cutoff. + + :param edge_distances: Distances between centers and their neighbors. + :param probe_cutoffs: Probe cutoff distances. + :param centers: Indices of the center atoms. + :param num_nodes: Total number of center atoms. + :param width: Width of the cutoff function. + :return: Effective number of neighbors for each center atom and probe cutoff. 
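+
+    Schematically (a sketch of what the code below computes), the effective
+    neighbor count of center ``i`` for probe cutoff ``c`` is::
+
+        N_eff[i, c] = sum_j cutoff_func(r_ij, probe_cutoffs[c], width)
+
+    where the sum runs over all edges ``j`` of center ``i`` and is accumulated
+    with ``index_add_``.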
+ """ + + weights = cutoff_func( + edge_distances.unsqueeze(0), probe_cutoffs.unsqueeze(1), width + ) + + probe_num_neighbors = torch.zeros( + (len(probe_cutoffs), num_nodes), + dtype=edge_distances.dtype, + device=edge_distances.device, + ) + # accumulate the weights for all probe cutoffs and center atoms at once + probe_num_neighbors.index_add_(1, centers, weights) + probe_num_neighbors = probe_num_neighbors.T + + return probe_num_neighbors + + +def get_gaussian_cutoff_weights( + effective_num_neighbors: torch.Tensor, + num_neighbors_adaptive: float, + width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the weights for each probe cutoff based on + the effective number of neighbors using Gaussian weights + centered at the expected number of neighbors. + + :param effective_num_neighbors: Effective number of neighbors for each center atom + and probe cutoff. + :param num_neighbors_adaptive: Target maximum number of neighbors per atom. + :param width: Width of the Gaussian cutoff selection function. + :return: Weights for each probe cutoff. + """ + diff = effective_num_neighbors - num_neighbors_adaptive + + # adds a "baseline" corresponding to uniformly-distributed atoms + # this has multiple "good" effects: it pushes the cutoff "out" when + # there are few neighbors, and "in" when there are many, and it + # stabilizes the weights with respect to variations in the neighbor + # distribution when there are empty ranges leading to "flat" + # neighbor count distribution + x = torch.linspace( + 0, + 1, + effective_num_neighbors.shape[1], + device=effective_num_neighbors.device, + dtype=effective_num_neighbors.dtype, + ) + baseline = num_neighbors_adaptive * x**3 + + diff = diff + baseline.unsqueeze(0) + if width is None: + # adaptive width from neighbor-count slope along probe axis (last dim) + eps = 1e-12 + if diff.shape[-1] == 1: + # Can't compute gradient from single point; use scaled diff as proxy + width_t = diff.abs() * 0.5 + eps + else: + # Compute numerical gradient: centered differences for interior, + # one-sided differences at boundaries + (width_t,) = torch.gradient(diff, dim=-1) + width_t = width_t.abs().clamp_min(eps) + else: + width_t = torch.ones_like(diff) * width + + logw = -0.5 * (diff / width_t) ** 2 + weights = torch.exp(logw - logw.max()) + + # row-wise normalization of the weights + weights_sum = weights.sum(dim=1, keepdim=True) + weights = weights / weights_sum + + return weights diff --git a/src/metatrain/pet_5/modules/finetuning.py b/src/metatrain/pet_5/modules/finetuning.py new file mode 100644 index 0000000000..27f2086f19 --- /dev/null +++ b/src/metatrain/pet_5/modules/finetuning.py @@ -0,0 +1,264 @@ +# mypy: disable-error-code=misc +# We ignore misc errors in this file because TypedDict +# with default values is not allowed by mypy. 
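+
+# An illustrative example (values are placeholders, not recommendations) of the
+# finetuning hypers defined in this file, e.g. for LoRA finetuning:
+#
+#     {
+#         "method": "lora",
+#         "read_from": "path/to/pretrained_model.ckpt",
+#         "config": {"rank": 4, "alpha": 8.0},
+#         "inherit_heads": {},
+#     }
+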
+from typing import Any, Optional, Tuple + +import torch +import torch.nn as nn +from typing_extensions import Literal, NotRequired, TypedDict + + +class LoRaFinetuneConfig(TypedDict): + """Configuration for LoRA finetuning strategy.""" + + rank: int + """Rank of the LoRA matrices.""" + alpha: float + """Scaling factor for the LoRA matrices.""" + target_modules: NotRequired[list[str]] + + +class HeadsFinetuneConfig(TypedDict): + """Configuration for heads finetuning strategy.""" + + head_modules: list[str] + """List of module name prefixes for the prediction heads to finetune.""" + last_layer_modules: list[str] + """List of module name prefixes for the last layers to finetune.""" + + +class NoFinetuneHypers(TypedDict): + """Hypers that indicate that no finetuning is to be applied.""" + + read_from: None = None + """No finetuning is indicated by setting this argument to None. + + The rest of finetuning hyperparameters are then ignored. + """ + method: NotRequired[Any] + config: NotRequired[Any] + inherit_heads: NotRequired[Any] + + +class FullFinetuneHypers(TypedDict): + """Hyperparameters to use full finetuning of PET models. + + This means all model parameters are trainable. + """ + + method: Literal["full"] = "full" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: NotRequired[Any] + """No configuration needed for full finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class LoRaFinetuneHypers(TypedDict): + """Hyperparameters for LoRA finetuning of PET models. + + Injects LoRA layers and finetunes only them. + """ + + method: Literal["lora"] = "lora" + """Finetuning method to use""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: LoRaFinetuneConfig + """Configuration for LoRA finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class HeadsFinetuneHypers(TypedDict): + """Hyperparameters for heads finetuning of PET models. + + Freezes all model parameters except for the prediction heads + and last layers. + """ + + method: Literal["heads"] = "heads" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: HeadsFinetuneConfig + """Configuration for heads finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +FinetuneHypers = FullFinetuneHypers | LoRaFinetuneHypers | HeadsFinetuneHypers + + +def apply_finetuning_strategy(model: nn.Module, strategy: FinetuneHypers) -> nn.Module: + """ + Apply the specified finetuning strategy to the model. + This function modifies the model in place based on the provided strategy. + + :param model: The model to be finetuned. + :param strategy: A dictionary specifying the finetuning strategy. + The strategy method can be one of the following: + - lora: Inject LoRA layers into the model, or reapply training if already + present. 
+ - heads: Freeze all parameters except for the heads and last layers. + - full: All parameters are trainable. + Additionally, the strategy can include an "inherit_heads" key, + which is a dictionary mapping the new trainable targets to the existing + targets in the model. This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization. + :return: The modified model with the finetuning strategy applied. + """ + + for param in model.parameters(): + param.requires_grad = True + + if strategy["method"] == "full": + # Full finetuning, all parameters are trainable + pass + + elif strategy["method"] == "lora": + lora_config = strategy["config"] + lora_already_applied = any(isinstance(m, LoRALinear) for m in model.modules()) + if not lora_already_applied: + model_device = next(model.parameters()).device + model_dtype = next(model.parameters()).dtype + model = inject_lora_layers( + model, + target_modules=tuple( + lora_config.get("target_modules", ["input_linear", "output_linear"]) + ), + rank=lora_config.get("rank", 4), + alpha=lora_config.get("alpha", 8), + device=model_device, + dtype=model_dtype, + ) + + # Freeze all except LoRA + for name, param in model.named_parameters(): + if "lora_" not in name: + param.requires_grad = False + + elif strategy["method"] == "heads": + heads_config = strategy.get( + "config", + { + "head_modules": ["node_heads", "edge_heads"], + "last_layer_modules": ["node_last_layers", "edge_last_layers"], + }, + ) + + head_keywords = heads_config.get("head_modules", []) + last_layer_keywords = heads_config.get("last_layer_modules", []) + + for name, param in model.named_parameters(): + if any(name.startswith(kw) for kw in head_keywords + last_layer_keywords): + param.requires_grad = True + else: + param.requires_grad = False + + else: + raise ValueError( + f"Unknown finetuning strategy: {strategy['method']}. Available methods " + "are: 'full', 'lora', 'heads'." + ) + + model.finetune_config = strategy + + inherit_heads_config = strategy["inherit_heads"] + if inherit_heads_config: + for dest_target_name, source_target_name in inherit_heads_config.items(): + model_parameters = dict(model.named_parameters()) + if not any(f".{source_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the source target name '{source_target_name}' was not found in " + "the model. Please specify the correct source target name." + ) + if not any(f".{dest_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the destination target name '{dest_target_name}' was not found " + "in the model. Please specify the correct destination target name." + ) + for name, param in model_parameters.items(): + if f".{source_target_name}." in name: + corresponding_dest_name = name.replace( + source_target_name, dest_target_name + ) + if corresponding_dest_name in model_parameters: + model_parameters[corresponding_dest_name].data.copy_(param.data) + else: + raise ValueError( + f"Destination head '{dest_target_name}' not found in model." + ) + return model + + +def inject_lora_layers( + model: nn.Module, + target_modules: Tuple[str, ...] = ("input_linear", "output_linear"), + rank: int = 4, + alpha: float = 1.0, + device: Optional[str] = None, + dtype: Optional[torch.dtype] = None, +) -> nn.Module: + """ + Inject LoRA layers into the model. 
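# --------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): applying the LoRA strategy from
# `apply_finetuning_strategy` above to a toy module. Assumes the patch is
# installed; the tiny model and the checkpoint path are made up purely for
# demonstration (the attribute name `input_linear` matches the default
# `target_modules` used for LoRA injection).
import torch
import torch.nn as nn

from metatrain.pet_5.modules.finetuning import LoRALinear, apply_finetuning_strategy


class TinyBlock(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.input_linear = nn.Linear(16, 16)  # will be wrapped by LoRALinear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.input_linear(x)


model = TinyBlock()
strategy = {
    "method": "lora",
    "read_from": "pretrained.ckpt",  # hypothetical checkpoint path
    "config": {"rank": 4, "alpha": 8.0},
    "inherit_heads": {},
}
model = apply_finetuning_strategy(model, strategy)

assert isinstance(model.input_linear, LoRALinear)
# only the low-rank A/B matrices remain trainable after the strategy is applied
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['input_linear.lora_A.weight', 'input_linear.lora_B.weight']
# --------------------------------------------------------------------------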
+ This function replaces the specified linear layers in the model with + LoRALinear layers. + + :param model: The model to modify. + :param target_modules: A tuple of strings specifying the names of the attributes of + the modules to be replaced with LoRA layers. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + :param device: The device to which the LoRA layers should be moved. If None, the + LoRA layers will be on the same device as the original model. + :param dtype: The data type to which the LoRA layers should be converted. If + None, the LoRA layers will have the same dtype as the original model. + :return: The modified model with LoRA layers injected. + """ + for _, module in model.named_modules(): + for attr in target_modules: + if hasattr(module, attr): + linear = getattr(module, attr) + if isinstance(linear, nn.Linear): + lora_linear = LoRALinear(linear, rank=rank, alpha=alpha) + lora_linear = lora_linear.to(dtype=dtype, device=device) + setattr(module, attr, lora_linear) + return model + + +class LoRALinear(nn.Module): + """ + LoRA Linear layer. + This is a wrapper around nn.Linear that adds LoRA functionality. + LoRA is a technique for low-rank adaptation of large language models. + It allows for efficient fine-tuning of large models by injecting low-rank + matrices into the model's weights. + + :param linear_layer: The original linear layer to be wrapped. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + """ + + def __init__(self, linear_layer: nn.Module, rank: int = 4, alpha: float = 1.0): + super().__init__() + self.linear = linear_layer + self.lora_A = nn.Linear(linear_layer.in_features, rank, bias=False) + self.lora_B = nn.Linear(rank, linear_layer.out_features, bias=False) + self.scaling = alpha / rank + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + self.scaling * self.lora_B(self.lora_A(x)) diff --git a/src/metatrain/pet_5/modules/nef.py b/src/metatrain/pet_5/modules/nef.py new file mode 100644 index 0000000000..20ed6dc8fc --- /dev/null +++ b/src/metatrain/pet_5/modules/nef.py @@ -0,0 +1,248 @@ +""" +Module with functions to manipulate NEF (Node Edge Feature) arrays. + +The NEF representation is what the internals of PET use. +In the NEF representation, the first dimension is the center node +(i.e. the "i" node in an "i -> j" edge), and the second dimension +is the edges for that node. Not all center nodes have the same number +of edges, so padding is used to ensure that all nodes have the same +number of edges. + +Most of the functions have the purpose of converting between +edge arrays with shape (n_edges, ...) and NEF arrays with shape +(n_nodes, n_edges_per_node, ...). +""" + +from typing import List, Optional, Tuple + +import torch + + +def get_nef_indices( + centers: torch.Tensor, n_nodes: int, n_edges_per_node: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes tensors of indices useful to convert between edge + and NEF layouts; the usage and function of `nef_indices` and + `nef_to_edges_neighbor` is clear in the ``edge_array_to_nef`` + and ``nef_array_to_edges`` functions below. + + :param centers: A 1D tensor of shape (n_edges,) containing the + indices of the center nodes for each edge, with the center nodes + being the "i" node in an "i -> j" edge. + :param n_nodes: The number of nodes in the graph. + :param n_edges_per_node: The maximum number of edges per node. 
+ + :return: A tuple with three tensors (nef_indices, nef_to_edges_neighbor, nef_mask). + In particular: + nef_array = edge_array[nef_indices] + edge_array = nef_array[centers, nef_to_edges_neighbor] + The third output, nef_mask, is a mask that can be used to + filter out the padding values in the NEF array, as different + nodes will have, in general, different number of edges. + """ + + bincount = torch.bincount(centers, minlength=n_nodes) + + arange = torch.arange(n_edges_per_node, device=centers.device) + arange_expanded = arange.view(1, -1).expand(n_nodes, -1) + nef_mask = arange_expanded < bincount.view(-1, 1) + + argsort = torch.argsort(centers, stable=True) + + nef_indices = torch.zeros( + (n_nodes, n_edges_per_node), dtype=torch.long, device=centers.device + ) + nef_indices[nef_mask] = argsort + + nef_to_edges_neighbor = torch.empty_like(centers, dtype=torch.long) + nef_to_edges_neighbor[argsort] = arange_expanded[nef_mask] + + return nef_indices, nef_to_edges_neighbor, nef_mask + + +def get_corresponding_edges(array: torch.Tensor) -> torch.Tensor: + """ + Computes the corresponding edge (i.e., the edge that goes in the + opposite direction) for each edge in the array; this is useful + in the message-passing operation. + + :param array: A 2D tensor of shape (n_edges, 5). For each i -> j + edge, the first column contains the index of the center node i, + the second column contains the index of the neighbor node j, + and the last three columns contain the cell shifts along x, y, and z + directions, respectively. + + :return: A 1D tensor of shape (n_edges,) containing, for each edge, + the index of the corresponding edge (i.e., the edge that goes + in the opposite direction). If the input array is empty, an + empty tensor is returned. + """ + + if array.numel() == 0: + return torch.empty((0,), dtype=array.dtype, device=array.device) + + array = array.to(torch.int64) # avoid overflow + + centers = array[:, 0] + neighbors = array[:, 1] + cell_shifts_x = array[:, 2] + cell_shifts_y = array[:, 3] + cell_shifts_z = array[:, 4] + + # will be useful later + negative_cell_shifts_x = -cell_shifts_x + negative_cell_shifts_y = -cell_shifts_y + negative_cell_shifts_z = -cell_shifts_z + + # create a unique identifier for each edge + # first, we shift the cell_shifts so that the minimum value is 0 + min_cell_shift_x = cell_shifts_x.min() + cell_shifts_x = cell_shifts_x - min_cell_shift_x + negative_cell_shifts_x = negative_cell_shifts_x - min_cell_shift_x + + min_cell_shift_y = cell_shifts_y.min() + cell_shifts_y = cell_shifts_y - min_cell_shift_y + negative_cell_shifts_y = negative_cell_shifts_y - min_cell_shift_y + + min_cell_shift_z = cell_shifts_z.min() + cell_shifts_z = cell_shifts_z - min_cell_shift_z + negative_cell_shifts_z = negative_cell_shifts_z - min_cell_shift_z + + max_centers_neigbors = centers.max() + 1 # same as neighbors.max() + 1 + max_shift_x = cell_shifts_x.max() + 1 + max_shift_y = cell_shifts_y.max() + 1 + max_shift_z = cell_shifts_z.max() + 1 + + size_1 = max_shift_z + size_2 = max_shift_y * size_1 + size_3 = max_shift_x * size_2 + size_4 = max_centers_neigbors * size_3 + + unique_id = ( + centers * size_4 + + neighbors * size_3 + + cell_shifts_x * size_2 + + cell_shifts_y * size_1 + + cell_shifts_z + ) + + # the inverse is the same, but centers and neighbors are swapped + # and we use the negative values of the cell_shifts + unique_id_inverse = ( + neighbors * size_4 + + centers * size_3 + + negative_cell_shifts_x * size_2 + + negative_cell_shifts_y * size_1 + + 
negative_cell_shifts_z + ) + + unique_id_argsort = unique_id.argsort() + unique_id_inverse_argsort = unique_id_inverse.argsort() + + corresponding_edges = torch.empty_like(centers) + corresponding_edges[unique_id_argsort] = unique_id_inverse_argsort + + return corresponding_edges.to(array.dtype) + + +def edge_array_to_nef( + edge_array: torch.Tensor, + nef_indices: torch.Tensor, + mask: Optional[torch.Tensor] = None, + fill_value: float = 0.0, +) -> torch.Tensor: + """Converts an edge array to a NEF array. + + :param edge_array: A tensor where the first dimension is the index of + the edge, i.e. with shape (n_edges, ...). + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param mask: An optional boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. If provided, + the output NEF array will have the values in the positions + where the mask is False set to ``fill_value``. + :param fill_value: The value to use to fill the positions in the + NEF array where the mask is False. Only used if ``mask`` is + provided. + + :return: A tensor with the same information as ``edge_array``, + but in NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + If ``mask`` is provided, the values in the positions where + the mask is False are set to ``fill_value``. + """ + if mask is None: + return edge_array[nef_indices] + else: + return torch.where( + mask.reshape(mask.shape + (1,) * (len(edge_array.shape) - 1)), + edge_array[nef_indices], + fill_value, + ) + + +def nef_array_to_edges( + nef_array: torch.Tensor, centers: torch.Tensor, nef_to_edges_neighbor: torch.Tensor +) -> torch.Tensor: + """Converts a NEF array to an edge array. + + :param nef_array: A tensor where the first two dimensions are the + indices of the NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + :param centers: The indices of the center nodes for each edge. + :param nef_to_edges_neighbor: The indices of the edges for each + neighbor in the NEF layout, as returned by the ``get_nef_indices`` function. + + :return: A tensor with the same information as ``nef_array``, + but in edge layout, i.e. with shape (n_edges, ...). + """ + return nef_array[centers, nef_to_edges_neighbor] + + +def compute_reversed_neighbor_list( + nef_indices: torch.Tensor, + corresponding_edges: torch.Tensor, + nef_mask: torch.Tensor, +) -> torch.Tensor: + """ + Creates a reversed neighborlist, where for each + center atom `i` and its neighbor `j` in the original + neighborlist, the position of atom `i` in the list + of neighbors of atom `j` is returned. + + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param corresponding_edges: The indices of the corresponding edges, + as returned by the ``get_corresponding_edges`` function. + :param nef_mask: A boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. + :return: A tensor of the same shape as ``nef_indices``, + where each entry contains the position of the center + atom in the neighborlist of the corresponding neighbor atom. 
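# --------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): round-tripping a small graph
# between the flat edge layout and the padded NEF layout using the helpers
# above. Assumes the patch is installed so they are importable from
# `metatrain.pet_5.modules.nef`.
import torch

from metatrain.pet_5.modules.nef import (
    edge_array_to_nef,
    get_corresponding_edges,
    get_nef_indices,
    nef_array_to_edges,
)

# 5 edges on a 3-node graph: node 0 has 3 edges, node 1 has 2, node 2 has none
centers = torch.tensor([0, 0, 1, 0, 1])
edge_features = torch.arange(5, dtype=torch.float64).unsqueeze(-1)  # (n_edges, 1)

n_nodes = 3
n_edges_per_node = int(torch.bincount(centers, minlength=n_nodes).max())
nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices(
    centers, n_nodes, n_edges_per_node
)

nef = edge_array_to_nef(edge_features, nef_indices, nef_mask, fill_value=0.0)
print(nef.shape)  # torch.Size([3, 3, 1]); padded slots hold the fill value

# converting back to the edge layout recovers the original array
recovered = nef_array_to_edges(nef, centers, nef_to_edges_neighbor)
assert torch.equal(recovered, edge_features)

# reverse edges: for 0 -> 1 and 1 -> 0 (no cell shifts), each edge's
# corresponding edge is the other one
pairs = torch.tensor([[0, 1, 0, 0, 0], [1, 0, 0, 0, 0]])
print(get_corresponding_edges(pairs))  # tensor([1, 0])
# --------------------------------------------------------------------------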
+ """ + num_atoms, max_num_neighbors = nef_indices.shape + + flat_edge_indices = nef_indices.reshape(-1) + flat_positions = torch.arange(max_num_neighbors, device=nef_indices.device).repeat( + num_atoms + ) + flat_mask = nef_mask.reshape(-1) + + if flat_edge_indices.numel() == 0: + max_edge_index = 0 + else: + max_edge_index = int(flat_edge_indices.max().item()) + 1 + size: List[int] = [max_edge_index] + + edge_index_to_position = torch.full( + size, + 0, + dtype=torch.long, + device=nef_indices.device, + ) + edge_index_to_position[flat_edge_indices[flat_mask]] = flat_positions[flat_mask] + + reverse_edge_idx = corresponding_edges[nef_indices] + reversed_neighbor_list = edge_index_to_position[reverse_edge_idx] + reversed_neighbor_list = reversed_neighbor_list.masked_fill(~nef_mask, 0) + + return reversed_neighbor_list diff --git a/src/metatrain/pet_5/modules/structures.py b/src/metatrain/pet_5/modules/structures.py new file mode 100644 index 0000000000..f7e67f9eaf --- /dev/null +++ b/src/metatrain/pet_5/modules/structures.py @@ -0,0 +1,377 @@ +from typing import List, Optional, Tuple + +import torch +from metatensor.torch import Labels +from metatomic.torch import NeighborListOptions, System + +from .adaptive_cutoff import get_adaptive_cutoffs +from .nef import ( + compute_reversed_neighbor_list, + edge_array_to_nef, + get_corresponding_edges, + get_nef_indices, +) +from .utilities import cutoff_func_bump, cutoff_func_cosine + + +def concatenate_structures( + systems: List[System], + neighbor_list_options: NeighborListOptions, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, +]: + """ + Concatenate a list of systems into a single batch. + + :param systems: List of systems to concatenate. + :param neighbor_list_options: Options for the neighbor list. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the concatenated positions, centers, neighbors, + species, cells, cell shifts, system indices, and sample labels. 
+ """ + + positions: List[torch.Tensor] = [] + centers: List[torch.Tensor] = [] + neighbors: List[torch.Tensor] = [] + species: List[torch.Tensor] = [] + cell_shifts: List[torch.Tensor] = [] + cells: List[torch.Tensor] = [] + system_indices: List[torch.Tensor] = [] + atom_indices: List[torch.Tensor] = [] + node_counter = 0 + + for i, system in enumerate(systems): + assert len(system.known_neighbor_lists()) >= 1, "no neighbor list found" + neighbor_list = system.get_neighbor_list(neighbor_list_options) + nl_values = neighbor_list.samples.values + + centers_values = nl_values[:, 0] + neighbors_values = nl_values[:, 1] + cell_shifts_values = nl_values[:, 2:] + + if selected_atoms is not None: + system_selected_atoms = selected_atoms.values[:, 1][ + selected_atoms.values[:, 0] == i + ] + unique_centers = torch.unique(centers_values) + system_selected_atoms = torch.unique( + torch.cat([system_selected_atoms, unique_centers]) + ) + # calculate the mapping from the ghost atoms to the real atoms + if torch.numel(unique_centers) == 0: + max_center_index = -1 + else: + max_center_index = int(unique_centers.max()) + ghost_to_real_index = torch.full( + [ + max_center_index + 1, + ], + -1, + device=centers_values.device, + dtype=centers_values.dtype, + ) + for j, unique_center_index in enumerate(unique_centers): + ghost_to_real_index[unique_center_index] = j + + centers_values = ghost_to_real_index[centers_values] + neighbors_values = ghost_to_real_index[neighbors_values] + else: + system_selected_atoms = torch.arange( + len(system), device=system.positions.device + ) + + positions.append(system.positions[system_selected_atoms]) + species.append(system.types[system_selected_atoms]) + + centers.append(centers_values + node_counter) + neighbors.append(neighbors_values + node_counter) + cell_shifts.append(cell_shifts_values) + + cells.append(system.cell) + + node_counter += len(system_selected_atoms) + system_indices.append( + torch.full((len(system_selected_atoms),), i, device=system.positions.device) + ) + atom_indices.append( + torch.arange(len(system_selected_atoms), device=system.positions.device) + ) + + positions = torch.cat(positions) + centers = torch.cat(centers) + neighbors = torch.cat(neighbors) + species = torch.cat(species) + cells = torch.stack(cells) + cell_shifts = torch.cat(cell_shifts) + system_indices = torch.cat(system_indices) + atom_indices = torch.cat(atom_indices) + + sample_values = torch.stack( + [system_indices, atom_indices], + dim=1, + ) + sample_labels = Labels( + names=["system", "atom"], + values=sample_values, + assume_unique=True, + ) + + return ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) + + +def systems_to_batch( + systems: List[System], + options: NeighborListOptions, + all_species_list: List[int], + species_to_species_index: torch.Tensor, + cutoff_function: str, + cutoff_width: float, + num_neighbors_adaptive: Optional[float] = None, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, + torch.Tensor, + torch.Tensor, +]: + """ + Converts a list of systems to a batch required for the PET model. + + :param systems: List of systems to convert to a batch. + :param options: Options for the neighbor list. + :param all_species_list: List of all atomic species in the dataset. + :param species_to_species_index: Mapping from atomic species to species indices. 
+ :param cutoff_function: Type of the smoothing function at the cutoff. + :param cutoff_width: Width of the cutoff function for a cutoff mask. + :param num_neighbors_adaptive: Optional maximum number of neighbors per atom. + If provided, the adaptive cutoff scheme will be used for each atom to + approximately select this number of neighbors. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the batch tensors. + The batch consists of the following tensors: + - `element_indices_nodes`: The atomic species of the central atoms + - `element_indices_neighbors`: The atomic species of the neighboring atoms + - `edge_vectors`: The cartesian edge vectors between the central atoms and their + neighbors + - `edge_distances`: The distances between the central atoms and their neighbors + - `padding_mask`: A padding mask indicating which neighbors are real, and which + are padded + - `reverse_neighbor_index`: The reversed neighbor list for each central atom + - `cutoff_factors`: The cutoff function values for each edge + - `system_indices`: The system index for each atom in the batch + - `sample_labels`: Labels indicating the system and atom indices for each atom + + """ + ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) = concatenate_structures(systems, options, selected_atoms) + + # somehow the backward of this operation is very slow at evaluation, + # where there is only one cell, therefore we simplify the calculation + # for that case + if len(cells) == 1: + cell_contributions = cell_shifts.to(cells.dtype) @ cells[0] + else: + cell_contributions = torch.einsum( + "ab, abc -> ac", + cell_shifts.to(cells.dtype), + cells[system_indices[centers]], + ) + edge_vectors = positions[neighbors] - positions[centers] + cell_contributions + edge_distances = torch.norm(edge_vectors, dim=-1) + 1e-15 + + if selected_atoms is not None: + if torch.numel(centers) == 0: + num_nodes = 0 + else: + num_nodes = int(centers.max()) + 1 + else: + num_nodes = len(positions) + + if num_neighbors_adaptive is not None: + with torch.profiler.record_function("PET::get_adaptive_cutoffs"): + # Adaptive cutoff scheme to approximately select `num_neighbors_adaptive` + # neighbors for each atom + atomic_cutoffs = get_adaptive_cutoffs( + centers, + edge_distances, + num_neighbors_adaptive, + num_nodes, + options.cutoff, + cutoff_width=cutoff_width, + ) + # Symmetrize the cutoffs between pairs of atoms (PET needs this symmetry + # due to its corresponding edge indexing ij -> ji) + pair_cutoffs = (atomic_cutoffs[centers] + atomic_cutoffs[neighbors]) / 2.0 + with torch.profiler.record_function("PET::adaptive_cutoff_masking"): + # Apply cutoff mask + cutoff_mask = edge_distances <= pair_cutoffs + + pair_cutoffs = pair_cutoffs[cutoff_mask] + centers = centers[cutoff_mask] + neighbors = neighbors[cutoff_mask] + edge_vectors = edge_vectors[cutoff_mask] + cell_shifts = cell_shifts[cutoff_mask] + edge_distances = edge_distances[cutoff_mask] + else: + pair_cutoffs = options.cutoff * torch.ones( + len(centers), device=positions.device, dtype=positions.dtype + ) + + num_neighbors = torch.bincount(centers, minlength=num_nodes) + max_edges_per_node = int(torch.max(num_neighbors)) + + # uncomment these to print out stats on the adaptive cutoff behavior + # print("adaptive_cutoffs", *pair_cutoffs.tolist()) + # print("num_neighbors", *num_neighbors.tolist()) + + if cutoff_function.lower() == "bump": + # use bump switching 
function for adaptive cutoff + cutoff_factors = cutoff_func_bump(edge_distances, pair_cutoffs, cutoff_width) + elif cutoff_function.lower() == "cosine": + # backward-compatible cosine swithcing for fixed cutoff + cutoff_factors = cutoff_func_cosine(edge_distances, pair_cutoffs, cutoff_width) + else: + raise ValueError( + f"Unknown cutoff function type: {cutoff_function}. " + f"Supported types are 'Cosine' and 'Bump'." + ) + + # Convert to NEF (Node-Edge-Feature) format: + nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices( + centers, num_nodes, max_edges_per_node + ) + + # Element indices + element_indices_nodes = species_to_species_index[species] + element_indices_neighbors = element_indices_nodes[neighbors] + + # Send everything to NEF: + edge_vectors = edge_array_to_nef(edge_vectors, nef_indices) + edge_distances = torch.sqrt(torch.sum(edge_vectors**2, dim=2) + 1e-15) + element_indices_neighbors = edge_array_to_nef( + element_indices_neighbors, nef_indices + ) + cutoff_factors = edge_array_to_nef(cutoff_factors, nef_indices, nef_mask, 0.0) + + corresponding_edges = get_corresponding_edges( + torch.concatenate( + [centers.unsqueeze(-1), neighbors.unsqueeze(-1), cell_shifts], + dim=-1, + ) + ) + + # These are the two arrays we need for message passing with edge reversals, + # if indexing happens in a two-dimensional way: + # edges_ji = edges_ij[reversed_neighbor_list, neighbors_index] + reversed_neighbor_list = compute_reversed_neighbor_list( + nef_indices, corresponding_edges, nef_mask + ) + neighbors_index = edge_array_to_nef(neighbors, nef_indices).to(torch.int64) + + # Here, we compute the array that allows indexing into a flattened + # version of the edge array (where the first two dimensions are merged): + reverse_neighbor_index = ( + neighbors_index * neighbors_index.shape[1] + reversed_neighbor_list + ) + # At this point, we have `reverse_neighbor_index[~nef_mask] = 0`, which however + # creates too many of the same index which slows down backward enormously. + # (See see https://github.com/pytorch/pytorch/issues/41162) + # We therefore replace the padded indices with a sequence of unique indices. + reverse_neighbor_index[~nef_mask] = torch.arange( + int(torch.sum(~nef_mask)), device=reverse_neighbor_index.device + ) + + return ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + nef_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + sample_labels, + centers, + nef_to_edges_neighbor, + ) + + +def get_pair_sample_labels( + systems: List[System], + sample_labels: Labels, + nl_options: NeighborListOptions, + device: torch.device, +) -> Labels: + """ + Builds the pair samples labels for the input ``systems``, based on the pre-computed + neighbor list. These are 'off-site', i.e. not including self-interactions. + + :param systems: List of systems to build the pair sample labels for. + :param sample_labels: The sample labels for per-atom quantities. + :param nl_options: The neighbor list options to use for building the offsite labels. + :param device: The device to put the labels on. + :return: A dictionary with the pair sample labels for the onsite and offsite blocks. 
+ """ + sample_names = [ + "system", + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ] + + pair_sample_values = [] + for system_idx, system in enumerate(systems): + neighbor_list = system.get_neighbor_list(nl_options) + nl_values = neighbor_list.samples.values + + pair_sample_values.append( + torch.hstack( + [ + torch.full( + (nl_values.shape[0], 1), + system_idx, + dtype=torch.int32, + device=device, + ), + nl_values, + ], + ) + ) + pair_sample_values = torch.vstack(pair_sample_values) + pair_sample_labels = Labels(sample_names, pair_sample_values).to(device=device) + + return pair_sample_labels diff --git a/src/metatrain/pet_5/modules/transformer.py b/src/metatrain/pet_5/modules/transformer.py new file mode 100644 index 0000000000..9993e4fd8a --- /dev/null +++ b/src/metatrain/pet_5/modules/transformer.py @@ -0,0 +1,555 @@ +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from .utilities import DummyModule +from sphericart.torch import SphericalHarmonics + + +AVAILABLE_NORMALIZATIONS = ["LayerNorm", "RMSNorm"] +AVAILABLE_TRANSFORMER_TYPES = ["PostLN", "PreLN"] +AVAILABLE_ACTIVATIONS = ["SiLU", "SwiGLU"] + + +class FeedForward(nn.Module): + def __init__(self, d_model: int, dim_feedforward: int, activation: str) -> None: + super().__init__() + + # Check if activation is "swiglu" string + if activation.lower() == "swiglu": + # SwiGLU mode: single projection produces both "value" and "gate" + self.w_in = nn.Linear(d_model, 2 * dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = torch.nn.Identity() + self.is_swiglu = True + else: + # Standard mode: regular activation function + self.w_in = nn.Linear(d_model, dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = getattr(F, activation.lower()) + self.is_swiglu = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_swiglu: + # SwiGLU activation: split into value and gate + v, g = self.w_in(x).chunk(2, dim=-1) + x = v * torch.sigmoid(g) + x = self.w_out(x) + else: + # Standard activation + x = self.w_in(x) + x = self.activation(x) + x = self.w_out(x) + return x + + +class AttentionBlock(nn.Module): + """ + Multi-head attention block. + + :param total_dim: The total dimension of the input and output tensors. + :param num_heads: The number of attention heads. + :param temperature: An additional scaling factor for attention scores. + This is combined with the standard scaling by the square root of + the head dimension. + :param epsilon: A small value to avoid division by zero. + """ + + def __init__( + self, + total_dim: int, + num_heads: int, + temperature: float, + epsilon: float = 1e-15, + ) -> None: + super(AttentionBlock, self).__init__() + + self.input_linear = nn.Linear(total_dim, 3 * total_dim) + self.output_linear = nn.Linear(total_dim, total_dim) + + self.num_heads = num_heads + self.epsilon = epsilon + self.temperature = temperature + if total_dim % num_heads != 0: + raise ValueError("total dimension is not divisible by the number of heads") + self.head_dim = total_dim // num_heads + + def forward( + self, x: torch.Tensor, cutoff_factors: torch.Tensor, use_manual_attention: bool + ) -> torch.Tensor: + """ + Forward pass for the attention block. 
+ + :param x: The input tensor, of shape (batch_size, seq_length, total_dim) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: The output tensor, of shape (batch_size, seq_length, total_dim) + """ + initial_shape = x.shape + x = self.input_linear(x) + x = x.reshape( + initial_shape[0], initial_shape[1], 3, self.num_heads, self.head_dim + ) + x = x.permute(2, 0, 3, 1, 4) + + queries, keys, values = x[0], x[1], x[2] + attn_weights = torch.clamp(cutoff_factors[:, None, :, :], self.epsilon) + attn_weights = torch.log(attn_weights) + if use_manual_attention: + x = manual_attention(queries, keys, values, attn_weights, self.temperature) + else: + x = torch.nn.functional.scaled_dot_product_attention( + queries, + keys, + values, + attn_mask=attn_weights, + scale=1.0 / (self.head_dim**0.5 * self.temperature), + ) + x = x.transpose(1, 2).reshape(initial_shape) + x = self.output_linear(x) + return x + + +class TransformerLayer(torch.nn.Module): + """ + Single layer of a Transformer. + + :param d_model: The dimension of the model. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param temperature: An additional scaling factor for attention scores. 
+ """ + + def __init__( + self, + d_model: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + temperature: float = 1.0, + ) -> None: + super(TransformerLayer, self).__init__() + self.attention = AttentionBlock(d_model, n_heads, temperature) + self.transformer_type = transformer_type + self.d_model = d_model + norm_class = getattr(nn, norm) + self.norm_attention = norm_class(d_model) + self.norm_mlp = norm_class(d_model) + self.mlp = FeedForward(d_model, dim_feedforward, activation) + self.expanded_node_features = False + if dim_node_features != d_model: + self.expanded_node_features = True + self.center_contraction = nn.Linear(dim_node_features, d_model) + self.center_expansion = nn.Linear(d_model, dim_node_features) + self.norm_center_features = norm_class(dim_node_features) + self.center_mlp = FeedForward( + dim_node_features, 2 * dim_node_features, activation + ) + else: + self.center_contraction = torch.nn.Identity() + self.center_expansion = torch.nn.Identity() + self.norm_center_features = torch.nn.Identity() + self.center_mlp = torch.nn.Identity() + + def _forward_pre_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + new_tokens = self.attention( + self.norm_attention(tokens), cutoff_factors, use_manual_attention + ) + output_node_embeddings, output_edge_embeddings = torch.split( + new_tokens, [1, new_tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + + output_edge_embeddings = edge_embeddings + output_edge_embeddings + output_edge_embeddings = output_edge_embeddings + self.mlp( + self.norm_mlp(output_edge_embeddings) + ) + + return output_node_embeddings, output_edge_embeddings + + def _forward_post_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + tokens = self.norm_attention( + tokens + self.attention(tokens, cutoff_factors, use_manual_attention) + ) + tokens = self.norm_mlp(tokens + self.mlp(tokens)) + output_node_embeddings, output_edge_embeddings = torch.split( + tokens, [1, tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + return output_node_embeddings, output_edge_embeddings + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward 
pass for a single Transformer layer. + + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + if self.transformer_type == "PostLN": + node_embeddings, edge_embeddings = self._forward_post_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + if self.transformer_type == "PreLN": + node_embeddings, edge_embeddings = self._forward_pre_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class Transformer(torch.nn.Module): + """ + Transformer implementation. + + :param d_model: The dimension of the model. + :param num_layers: The number of transformer layers. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param attention_temperature: The temperature scaling factor for attention + scores. This is combined with the standard scaling by the square root of + the head dimension. + """ + + def __init__( + self, + d_model: int, + num_layers: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + attention_temperature: float = 1.0, + ) -> None: + super(Transformer, self).__init__() + if norm not in AVAILABLE_NORMALIZATIONS: + raise ValueError( + f"Unknown normalization flag: {norm}. " + f"Please choose from: {AVAILABLE_NORMALIZATIONS}" + ) + + if transformer_type not in AVAILABLE_TRANSFORMER_TYPES: + raise ValueError( + f"Unknown transformer flag: {transformer_type}. " + f"Please choose from: {AVAILABLE_TRANSFORMER_TYPES}" + ) + self.transformer_type = transformer_type + + if activation not in AVAILABLE_ACTIVATIONS: + raise ValueError( + f"Unknown activation flag: {activation}. " + f"Please choose from: {AVAILABLE_ACTIVATIONS}" + ) + + self.layers = nn.ModuleList( + [ + TransformerLayer( + d_model=d_model, + n_heads=n_heads, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + temperature=attention_temperature, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the Transformer. 
+ + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + for layer in self.layers: + node_embeddings, edge_embeddings = layer( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class CartesianTransformer(torch.nn.Module): + """ + Cartesian Transformer implementation for handling 3D coordinates. + + :param cutoff: The cutoff distance for neighbor interactions. + :param cutoff_width: The width of the cutoff function. + :param d_model: The dimension of the model. + :param n_head: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param n_layers: The number of transformer layers. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param attention_temperature: The temperature scaling factor for attention scores. + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param n_atomic_species: The number of atomic species. + :param is_first: Whether this is the first transformer in the model. + """ + + def __init__( + self, + cutoff: float, + cutoff_width: float, + d_model: int, + n_head: int, + dim_node_features: int, + dim_feedforward: int, + n_layers: int, + norm: str, + activation: str, + attention_temperature: float, + transformer_type: str, + n_atomic_species: int, + is_first: bool, + ) -> None: + super(CartesianTransformer, self).__init__() + self.is_first = is_first + self.cutoff = cutoff + self.cutoff_width = cutoff_width + self.trans = Transformer( + d_model=d_model, + num_layers=n_layers, + n_heads=n_head, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + attention_temperature=attention_temperature, + ) + + self.spherical_harmonics = SphericalHarmonics(l_max=10) + self.edge_embedder = nn.Linear(11**2, d_model) + self.rmsnorm = nn.LayerNorm(d_model) + + if not is_first: + n_merge = 3 + else: + n_merge = 2 + + self.compress = nn.Sequential( + nn.Linear(n_merge * d_model, d_model), + torch.nn.SiLU(), + nn.Linear(d_model, d_model), + ) + + self.neighbor_embedder = DummyModule() # for torchscript + if not is_first: + self.neighbor_embedder = nn.Embedding(n_atomic_species, d_model) + + def forward( + self, + input_node_embeddings: torch.Tensor, + input_messages: torch.Tensor, + element_indices_neighbors: torch.Tensor, + edge_vectors: torch.Tensor, + padding_mask: torch.Tensor, + edge_distances: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the CartesianTransformer. 
+ + :param input_node_embeddings: The input node embeddings, of shape + (n_nodes, d_model) + :param input_messages: The input messages to the transformer, of shape + (n_nodes, max_num_neighbors, d_model) + :param element_indices_neighbors: The atomic species of the neighboring atoms, + of shape (n_nodes, max_num_neighbors) + :param edge_vectors: The cartesian edge vectors between the central atoms and + their neighbors, of shape (n_nodes, max_num_neighbors, 3) + :param padding_mask: A padding mask indicating which neighbors are real, and + which are padded, of shape (n_nodes, max_num_neighbors) + :param edge_distances: The distances between the central atoms and their + neighbors, of shape (n_nodes, max_num_neighbors) + :param cutoff_factors: The cutoff factors for the edges, of shape + (n_nodes, max_num_neighbors) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (n_nodes, d_model) + - The output edge embeddings, of shape (n_nodes, max_num_neighbors, d_model) + """ + node_embeddings = input_node_embeddings + # edge_embeddings = [edge_vectors, edge_distances[:, :, None]] + # edge_embeddings = torch.cat(edge_embeddings, dim=2) + edge_embeddings = self.spherical_harmonics( + edge_vectors.reshape(-1, 3) + ).reshape(edge_vectors.shape[0], edge_vectors.shape[1], -1) + edge_embeddings = self.edge_embedder(edge_embeddings) + edge_embeddings = self.rmsnorm(edge_embeddings) + + if not self.is_first: + neighbor_elements_embeddings = self.neighbor_embedder( + element_indices_neighbors + ) + edge_tokens = torch.cat( + [edge_embeddings, neighbor_elements_embeddings, input_messages], dim=2 + ) + else: + neighbor_elements_embeddings = torch.empty( + 0, device=edge_vectors.device, dtype=edge_vectors.dtype + ) # for torch script + edge_tokens = torch.cat([edge_embeddings, input_messages], dim=2) + + edge_tokens = self.compress(edge_tokens) + # tokens = torch.cat([node_elements_embedding[:, None, :], tokens], dim=1) + + padding_mask_with_central_token = torch.ones( + padding_mask.shape[0], dtype=torch.bool, device=padding_mask.device + ) + total_padding_mask = torch.cat( + [padding_mask_with_central_token[:, None], padding_mask], dim=1 + ) + + cutoff_subfactors = torch.ones( + padding_mask.shape[0], + dtype=cutoff_factors.dtype, + device=padding_mask.device, + ) + cutoff_factors = torch.cat([cutoff_subfactors[:, None], cutoff_factors], dim=1) + cutoff_factors[~total_padding_mask] = 0.0 + + cutoff_factors = cutoff_factors[:, None, :] + cutoff_factors = cutoff_factors.repeat(1, cutoff_factors.shape[2], 1) + + initial_num_tokens = edge_vectors.shape[1] + max_num_tokens = input_messages.shape[1] + + output_node_embeddings, output_edge_embeddings = self.trans( + node_embeddings[:, None, :], + edge_tokens[:, :max_num_tokens, :], + cutoff_factors=cutoff_factors[ + :, : (max_num_tokens + 1), : (max_num_tokens + 1) + ], + use_manual_attention=use_manual_attention, + ) + if max_num_tokens < initial_num_tokens: + padding = torch.zeros( + output_edge_embeddings.shape[0], + initial_num_tokens - max_num_tokens, + output_edge_embeddings.shape[2], + device=output_edge_embeddings.device, + ) + output_edge_embeddings = torch.cat([output_edge_embeddings, padding], dim=1) + output_node_embeddings = output_node_embeddings.squeeze(1) + return output_node_embeddings, 
output_edge_embeddings + + +def manual_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: torch.Tensor, + temperature: float, +) -> torch.Tensor: + """ + Implements the attention operation manually, using basic PyTorch operations. + We need it because the built-in PyTorch attention does not support double backward, + which is needed when training with conservative forces. + + :param q: The queries + :param k: The keys + :param v: The values + :param attn_mask: The attention mask + :param temperature: An additional scaling factor for attention scores. + :return: The result of the attention operation + """ + attention_weights = ( + torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5 * temperature) + ) + attn_mask + attention_weights = attention_weights.softmax(dim=-1) + attention_output = torch.matmul(attention_weights, v) + return attention_output \ No newline at end of file diff --git a/src/metatrain/pet_5/modules/utilities.py b/src/metatrain/pet_5/modules/utilities.py new file mode 100644 index 0000000000..5795c5aed8 --- /dev/null +++ b/src/metatrain/pet_5/modules/utilities.py @@ -0,0 +1,63 @@ +import torch + + +def cutoff_func_bump( + values: torch.Tensor, cutoff: torch.Tensor, width: float, eps: float = 1e-6 +) -> torch.Tensor: + """ + Bump cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :param eps: Avoid computing at values too close to the edges. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + f[mask_active] = 0.5 * ( + 1 + torch.tanh(1 / torch.tan(torch.pi * scaled_values[mask_active])) + ) + f[mask_smaller] = 1.0 + + return f + + +def cutoff_func_cosine( + values: torch.Tensor, cutoff: torch.Tensor, width: float +) -> torch.Tensor: + """ + Cosine cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + + f[mask_active] = 0.5 + 0.5 * torch.cos(torch.pi * scaled_values[mask_active]) + f[mask_smaller] = 1.0 + return f + + +class DummyModule(torch.nn.Module): + """Dummy torch module to make torchscript happy. 
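# --------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): comparing the two switching
# functions defined above. Both are 1 well inside the cutoff, 0 beyond it,
# and switch smoothly over the last `width` of the interval. Assumes the
# patch is installed so they are importable.
import torch

from metatrain.pet_5.modules.utilities import cutoff_func_bump, cutoff_func_cosine

distances = torch.linspace(0.0, 6.0, 13)
cutoff = torch.full_like(distances, 5.0)
width = 1.0

print(cutoff_func_cosine(distances, cutoff, width))
print(cutoff_func_bump(distances, cutoff, width))
# both are 1.0 for r <= 4.0 and 0.0 for r >= 5.0
# --------------------------------------------------------------------------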
+ This model should never be run""" + + def __init__(self) -> None: + super(DummyModule, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise RuntimeError("This model should never be run") diff --git a/src/metatrain/pet_5/trainer.py b/src/metatrain/pet_5/trainer.py new file mode 100644 index 0000000000..49e82212bc --- /dev/null +++ b/src/metatrain/pet_5/trainer.py @@ -0,0 +1,639 @@ +import copy +import logging +import math +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Union, cast + +import torch +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader, DistributedSampler + +from metatrain.utils.abc import ModelInterface, TrainerInterface +from metatrain.utils.additive import get_remove_additive_transform +from metatrain.utils.augmentation import RotationalAugmenter +from metatrain.utils.data import ( + CollateFn, + CombinedDataLoader, + Dataset, + get_num_workers, + unpack_batch, + validate_num_workers, +) +from metatrain.utils.distributed.batch_utils import should_skip_batch +from metatrain.utils.distributed.distributed_data_parallel import ( + DistributedDataParallel, +) +from metatrain.utils.distributed.slurm import DistributedEnvironment +from metatrain.utils.evaluate_model import evaluate_model +from metatrain.utils.io import check_file_extension +from metatrain.utils.logging import ROOT_LOGGER, MetricLogger +from metatrain.utils.loss import LossAggregator, LossSpecification +from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists_transform, +) +from metatrain.utils.per_atom import average_by_num_atoms +from metatrain.utils.scaler import get_remove_scale_transform +from metatrain.utils.transfer import batch_to + +from . import checkpoints +from .documentation import TrainerHypers +from .model import PET +from .modules.finetuning import apply_finetuning_strategy + + +def get_scheduler( + optimizer: torch.optim.Optimizer, + train_hypers: TrainerHypers, + steps_per_epoch: int, +) -> LambdaLR: + """ + Get a CosineAnnealing learning-rate scheduler with warmup + + :param optimizer: The optimizer for which to create the scheduler. + :param train_hypers: The training hyperparameters. + :param steps_per_epoch: The number of steps per epoch. + :return: The learning rate scheduler. 
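# --------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the warmup + cosine schedule
# produced by `get_scheduler`. Only the two hyperparameters it reads
# (`num_epochs` and `warmup_fraction`) are provided; the plain dict stands in
# for the full TrainerHypers. Assumes the patch is installed.
import torch

from metatrain.pet_5.trainer import get_scheduler

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-3)

hypers = {"num_epochs": 10, "warmup_fraction": 0.1}  # stand-in for TrainerHypers
scheduler = get_scheduler(optimizer, hypers, steps_per_epoch=100)

lrs = []
for _ in range(10 * 100):
    optimizer.step()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

# linear ramp to 1e-3 over the first 100 steps, then cosine decay towards 0
print(max(lrs), lrs[99], lrs[-1])
# --------------------------------------------------------------------------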
+ """ + total_steps = train_hypers["num_epochs"] * steps_per_epoch + warmup_steps = int(train_hypers["warmup_fraction"] * total_steps) + min_lr_ratio = 0.0 # hardcoded for now, could be made configurable in the future + + def lr_lambda(current_step: int) -> float: + if current_step < warmup_steps: + # Linear warmup + return float(current_step) / float(max(1, warmup_steps)) + else: + # Cosine decay + progress = (current_step - warmup_steps) / float( + max(1, total_steps - warmup_steps) + ) + cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) + return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay + + scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) + return scheduler + + +class Trainer(TrainerInterface[TrainerHypers]): + __checkpoint_version__ = 12 + + def __init__(self, hypers: TrainerHypers) -> None: + super().__init__(hypers) + + self.optimizer_state_dict: Optional[Dict[str, Any]] = None + self.scheduler_state_dict: Optional[Dict[str, Any]] = None + self.epoch: Optional[int] = None + self.best_epoch: Optional[int] = None + self.best_metric: Optional[float] = None + self.best_model_state_dict: Optional[Dict[str, Any]] = None + self.best_optimizer_state_dict: Optional[Dict[str, Any]] = None + + def train( + self, + model: PET, + dtype: torch.dtype, + devices: List[torch.device], + train_datasets: List[Union[Dataset, torch.utils.data.Subset]], + val_datasets: List[Union[Dataset, torch.utils.data.Subset]], + checkpoint_dir: str, + ) -> None: + assert dtype in PET.__supported_dtypes__ + + is_distributed = self.hypers["distributed"] + is_finetune = self.hypers["finetune"]["read_from"] is not None + + if is_distributed: + if len(devices) > 1: + raise ValueError( + "Requested distributed training with the `multi-gpu` device. " + " If you want to run distributed training with PET, please " + "set `device` to cuda." 
+ ) + # the calculation of the device number works both when GPUs on different + # processes are not visible to each other and when they are + distr_env = DistributedEnvironment(self.hypers["distributed_port"]) + device_number = distr_env.local_rank % torch.cuda.device_count() + device = torch.device("cuda", device_number) + torch.distributed.init_process_group(backend="nccl", device_id=device) + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + else: + rank = 0 + device = devices[0] + # only one device, as we don't support non-distributed multi-gpu for now + + if is_distributed: + logging.info(f"Training on {world_size} devices with dtype {dtype}") + else: + logging.info(f"Training on device {device} with dtype {dtype}") + + # Apply fine-tuning strategy if provided + if is_finetune: + assert self.hypers["finetune"]["read_from"] is not None # for mypy + model = apply_finetuning_strategy(model, self.hypers["finetune"]) + method = self.hypers["finetune"]["method"] + num_params = sum(p.numel() for p in model.parameters()) + num_trainable_params = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + + logging.info(f"Applied finetuning strategy: {method}") + logging.info( + f"Number of trainable parameters: {num_trainable_params} " + f"[{num_trainable_params / num_params:.2%} %]" + ) + inherit_heads = self.hypers["finetune"]["inherit_heads"] + if inherit_heads: + logging.info( + "Inheriting initial weights for heads and last layers for targets: " + f"from {list(inherit_heads.values())} to " + f"{list(inherit_heads.keys())}" + ) + + # Move the model to the device and dtype: + model.to(device=device, dtype=dtype) + # The additive models of PET are always in float64 (to avoid numerical errors in + # the composition weights, which can be very large). 
+ for additive_model in model.additive_models: + additive_model.to(dtype=torch.float64) + model.scaler.to(dtype=torch.float64) + + logging.info("Calculating composition weights") + model.additive_models[0].train_model( # this is the composition model + train_datasets, + model.additive_models[1:], + self.hypers["batch_size"], + is_distributed, + self.hypers["atomic_baseline"], + ) + + if self.hypers["scale_targets"]: + logging.info("Calculating scaling weights") + model.scaler.train_model( + train_datasets, + model.additive_models, + self.hypers["batch_size"], + is_distributed, + self.hypers["fixed_scaling_weights"], + ) + + logging.info("Setting up data loaders") + + if is_distributed: + train_samplers = [ + DistributedSampler( + train_dataset, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + ) + for train_dataset in train_datasets + ] + val_samplers = [ + DistributedSampler( + val_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + for val_dataset in val_datasets + ] + else: + train_samplers = [None] * len(train_datasets) + val_samplers = [None] * len(val_datasets) + + # Extract additive models and scaler and move them to CPU/float64 so they + # can be used in the collate function + model.additive_models[0].weights_to(device="cpu", dtype=torch.float64) + additive_models = copy.deepcopy( + model.additive_models.to(dtype=torch.float64, device="cpu") + ) + model.additive_models.to(device) + model.additive_models[0].weights_to(device=device, dtype=torch.float64) + model.scaler.scales_to(device="cpu", dtype=torch.float64) + scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu")) + model.scaler.to(device) + model.scaler.scales_to(device=device, dtype=torch.float64) + + # Create collate functions: + dataset_info = model.dataset_info + train_targets = dataset_info.targets + extra_data_info = dataset_info.extra_data + rotational_augmenter = RotationalAugmenter( + target_info_dict=train_targets, extra_data_info_dict=extra_data_info + ) + requested_neighbor_lists = get_requested_neighbor_lists(model) + collate_fn_train = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ + rotational_augmenter.apply_random_augmentations, + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + collate_fn_val = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ # no augmentation for validation + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + + # Create dataloader for the training datasets: + if self.hypers["num_workers"] is None: + num_workers = get_num_workers() + logging.info( + "Number of workers for data-loading not provided and chosen " + f"automatically. Using {num_workers} workers." + ) + else: + num_workers = self.hypers["num_workers"] + validate_num_workers(num_workers) + + train_dataloaders = [] + for train_dataset, train_sampler in zip( + train_datasets, train_samplers, strict=True + ): + if len(train_dataset) < self.hypers["batch_size"]: + raise ValueError( + f"A training dataset has fewer samples " + f"({len(train_dataset)}) than the batch size " + f"({self.hypers['batch_size']}). " + "Please reduce the batch size." 
+ ) + train_dataloaders.append( + DataLoader( + dataset=train_dataset, + batch_size=self.hypers["batch_size"], + sampler=train_sampler, + shuffle=( + # the sampler takes care of this (if present) + train_sampler is None + ), + drop_last=( + # the sampler takes care of this (if present) + train_sampler is None + ), + collate_fn=collate_fn_train, + num_workers=num_workers, + ) + ) + train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True) + + # Create dataloader for the validation datasets: + val_dataloaders = [] + for val_dataset, val_sampler in zip(val_datasets, val_samplers, strict=True): + val_dataloaders.append( + DataLoader( + dataset=val_dataset, + batch_size=self.hypers["batch_size"], + sampler=val_sampler, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_val, + num_workers=num_workers, + ) + ) + val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False) + + if is_distributed: + model = DistributedDataParallel(model, device_ids=[device]) + + outputs_list = [] + for target_name, target_info in train_targets.items(): + outputs_list.append(target_name) + for gradient_name in target_info.gradients: + outputs_list.append(f"{target_name}_{gradient_name}_gradients") + + # Create a loss function: + loss_hypers = cast(Dict[str, LossSpecification], self.hypers["loss"]) # mypy + loss_fn = LossAggregator(targets=train_targets, config=loss_hypers) + logging.info("Using the following loss functions:") + for name, info in loss_fn.metadata.items(): + logging.info(f"{name}:") + main = {k: v for k, v in info.items() if k != "gradients"} + logging.info(main) + if "gradients" not in info or len(info["gradients"]) == 0: + continue + logging.info("With gradients:") + for grad, ginfo in info["gradients"].items(): + logging.info(f"\t{name}::{grad}: {ginfo}") + + if self.hypers["weight_decay"] is not None: + optimizer = torch.optim.AdamW( + model.parameters(), + lr=self.hypers["learning_rate"], + weight_decay=self.hypers["weight_decay"], + ) + else: + optimizer = torch.optim.Adam( + model.parameters(), lr=self.hypers["learning_rate"] + ) + + if self.optimizer_state_dict is not None and not is_finetune: + # try to load the optimizer state dict, but this is only possible + # if there are no new targets in the model (new parameters) + if not (model.module if is_distributed else model).has_new_targets: + optimizer.load_state_dict(self.optimizer_state_dict) + + # Create a learning rate scheduler + lr_scheduler = get_scheduler(optimizer, self.hypers, len(train_dataloader)) + + if self.scheduler_state_dict is not None and not is_finetune: + # same as the optimizer, try to load the scheduler state dict + if not (model.module if is_distributed else model).has_new_targets: + lr_scheduler.load_state_dict(self.scheduler_state_dict) + + per_structure_targets = self.hypers["per_structure_targets"] + + # Log the initial learning rate: + logging.info(f"Base learning rate: {self.hypers['learning_rate']}") + + start_epoch = -1 if self.epoch is None else self.epoch + epoch = start_epoch + + # Save the untrained model checkpoint: + if self.best_metric is None: + self.best_metric = float("inf") + logging.info("Saving untrained model checkpoint") + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = 
lr_scheduler.state_dict() + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_untrained.ckpt", + ) + + # Train the model: + logging.info("Starting training") + start_epoch = start_epoch + 1 + + for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]): + if is_distributed: + for train_sampler in train_samplers: + train_sampler.set_epoch(epoch) + train_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + val_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + if self.hypers["log_mae"]: + train_mae_calculator = MAEAccumulator( + self.hypers["log_separate_blocks"] + ) + val_mae_calculator = MAEAccumulator(self.hypers["log_separate_blocks"]) + + train_loss = 0.0 + for batch in train_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + optimizer.zero_grad() + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=True, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + train_loss_batch = loss_fn(predictions, targets, extra_data) + + if is_distributed: + # make sure all parameters contribute to the gradient calculation + # to make torch DDP happy + for param in model.parameters(): + train_loss_batch += 0.0 * param.sum() + + train_loss_batch.backward() + torch.nn.utils.clip_grad_norm_( + model.parameters(), self.hypers["grad_clip_norm"] + ) + optimizer.step() + lr_scheduler.step() + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(train_loss_batch) + train_loss += train_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + train_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + train_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_train_info = train_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_train_info.update( + train_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + val_loss = 0.0 + for batch in val_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=False, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + val_loss_batch = loss_fn(predictions, targets, extra_data) + + if 
is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(val_loss_batch) + val_loss += val_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + val_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + val_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_val_info = val_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_val_info.update( + val_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + # Now we log the information: + finalized_train_info = { + "loss": train_loss, + **finalized_train_info, + } + finalized_val_info = { + "loss": val_loss, + **finalized_val_info, + } + + if epoch == start_epoch: + metric_logger = MetricLogger( + log_obj=ROOT_LOGGER, + dataset_info=( + model.module if is_distributed else model + ).dataset_info, + initial_metrics=[finalized_train_info, finalized_val_info], + names=["training", "validation"], + ) + if epoch % self.hypers["log_interval"] == 0: + metric_logger.log( + metrics=[finalized_train_info, finalized_val_info], + epoch=epoch, + rank=rank, + learning_rate=optimizer.param_groups[0]["lr"], + ) + + val_metric = get_selected_metric( + finalized_val_info, self.hypers["best_model_metric"] + ) + if val_metric < self.best_metric: + self.best_metric = val_metric + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + if epoch % self.hypers["checkpoint_interval"] == 0: + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + self.epoch = epoch + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_{epoch}.ckpt", + ) + + # prepare for the checkpoint that will be saved outside the function + self.epoch = epoch + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + + if is_distributed: + torch.distributed.destroy_process_group() + + def save_checkpoint(self, model: ModelInterface, path: Union[str, Path]) -> None: + checkpoint = model.get_checkpoint() + if self.best_model_state_dict is not None: + self.best_model_state_dict["finetune_config"] = model.finetune_config + checkpoint.update( + { + "trainer_ckpt_version": self.__checkpoint_version__, + "train_hypers": self.hypers, + "epoch": self.epoch, + "optimizer_state_dict": self.optimizer_state_dict, + "scheduler_state_dict": self.scheduler_state_dict, + "best_epoch": self.best_epoch, + "best_metric": self.best_metric, + "best_model_state_dict": self.best_model_state_dict, + "best_optimizer_state_dict": self.best_optimizer_state_dict, + } + ) + torch.save( + checkpoint, + check_file_extension(path, ".ckpt"), + ) + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + hypers: TrainerHypers, + context: Literal["restart", "finetune"], + ) -> "Trainer": + trainer = cls(hypers) + trainer.optimizer_state_dict = 
checkpoint["optimizer_state_dict"] + trainer.scheduler_state_dict = checkpoint["scheduler_state_dict"] + trainer.epoch = checkpoint["epoch"] + trainer.best_epoch = checkpoint["best_epoch"] + trainer.best_metric = checkpoint["best_metric"] + trainer.best_model_state_dict = checkpoint["best_model_state_dict"] + trainer.best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"] + + return trainer + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["trainer_ckpt_version"] == v: + update = getattr(checkpoints, f"trainer_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["trainer_ckpt_version"] = v + 1 + + if checkpoint["trainer_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using " + f"trainer version {checkpoint['trainer_ckpt_version']}, while the " + f"current trainer version is {cls.__checkpoint_version__}." + ) + return checkpoint diff --git a/src/metatrain/pet_6/__init__.py b/src/metatrain/pet_6/__init__.py new file mode 100644 index 0000000000..01db7d8aa9 --- /dev/null +++ b/src/metatrain/pet_6/__init__.py @@ -0,0 +1,21 @@ +from .model import PET +from .trainer import Trainer + + +__model__ = PET +__trainer__ = Trainer +__capabilities__ = { + "supported_devices": __model__.__supported_devices__, + "supported_dtypes": __model__.__supported_dtypes__, +} + +__authors__ = [ + ("Sergey Pozdnyakov ", "@spozdn"), + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] + +__maintainers__ = [ + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] diff --git a/src/metatrain/pet_6/checkpoints.py b/src/metatrain/pet_6/checkpoints.py new file mode 100644 index 0000000000..020b1e6590 --- /dev/null +++ b/src/metatrain/pet_6/checkpoints.py @@ -0,0 +1,422 @@ +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + + +########################### +# MODEL ################### +########################### + + +def model_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "additive_models.0.model.type_to_index" not in state_dict: + state_dict["additive_models.0.model.type_to_index"] = state_dict.pop( + "additive_models.0.type_to_index" + ) + + +def model_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "train_hypers" in checkpoint: + finetune_config = checkpoint["train_hypers"].get("finetune", {}) + else: + finetune_config = {} + state_dict["finetune_config"] = finetune_config + + +def model_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["epoch"] = checkpoint.get("epoch") + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + if checkpoint["best_model_state_dict"] is not None: + checkpoint["best_model_state_dict"] = checkpoint.get("best_model_state_dict") + + +def model_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 checkpoint to v5. + + :param checkpoint: The checkpoint to update. 
+ """ + checkpoint["model_data"]["dataset_info"]._atomic_types = list( + checkpoint["model_data"]["dataset_info"]._atomic_types + ) + + +def model_update_v5_v6(checkpoint: dict) -> None: + """ + Update a v5 checkpoint to v6. + + :param checkpoint: The checkpoint to update. + """ + if not checkpoint["best_model_state_dict"]: + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + +def model_update_v6_v7(checkpoint: dict) -> None: + """ + Update model checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to be updated. + """ + # this update consists in changes in the scaler + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if ( + "scaler.scales" not in state_dict + and "scaler.dummy_buffer" in state_dict + and "scaler.model.type_to_index" in state_dict + ): + continue # already updated + old_scales_tensor = state_dict.pop("scaler.scales") + old_output_name_to_output_index = {} + for target_index, target_name in enumerate( + checkpoint["model_data"]["dataset_info"].targets.keys() + ): + old_output_name_to_output_index[target_name] = target_index + state_dict["scaler.dummy_buffer"] = torch.tensor( + [0.0], dtype=old_scales_tensor.dtype + ) + state_dict["scaler.model.type_to_index"] = state_dict[ + "additive_models.0.model.type_to_index" + ] + for target_name, target_info in checkpoint["model_data"][ + "dataset_info" + ].targets.items(): + layout = target_info.layout + if layout.sample_names == ["system"]: + samples = Labels(["atomic_type"], torch.tensor([[-1]])) + + elif layout.sample_names == ["system", "atom"]: + samples = Labels( + ["atomic_type"], + torch.arange( + len(checkpoint["model_data"]["dataset_info"].atomic_types) + ).reshape(-1, 1), + ) + else: + raise ValueError( # will never happen + "Unknown sample kind. Please contact the developers." + ) + scales_tensormap = TensorMap( + keys=layout.keys, + blocks=[ + TensorBlock( + values=torch.full( # important when scale_targets=False + (len(samples), len(block.properties)), + old_scales_tensor[ + old_output_name_to_output_index[target_name] + ], + dtype=torch.float64, + ), + samples=samples, + components=[], + properties=block.properties, + ) + for block in layout.blocks() + ], + ) + state_dict[f"scaler.{target_name}_scaler_buffer"] = mts.save_buffer( + mts.make_contiguous(scales_tensormap) + ) + + +def model_update_v7_v8(checkpoint: dict) -> None: + """ + Update a v7 checkpoint to v8. + + :param checkpoint: The checkpoint to update. + """ + if "node_embedding.weight" in checkpoint["model_state_dict"]: + ############################################## + # **Updating the large-scale PET checkpoints** + ############################################## + + checkpoint["model_data"]["model_hypers"]["normalization"] = "RMSNorm" + checkpoint["model_data"]["model_hypers"]["activation"] = "SwiGLU" + checkpoint["model_data"]["model_hypers"]["d_node"] = ( + 4 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PreLN" + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "feedforward" + checkpoint["model_data"]["model_hypers"]["d_feedforward"] = ( + 2 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + if (state_dict := checkpoint.get("model_state_dict")) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + if "embedding." 
in k and "node" not in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if "node_embedding." in k: + k = k.replace("node_embedding.", "node_embedders.0.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + if "combination_rmsnorms" in k: + k = k.replace("combination_rmsnorms", "combination_norms") + new_state_dict[k] = v + checkpoint["model_state_dict"] = new_state_dict + + # for the large-scale checkpoints, the model for evaluation is always + # taken to be the last + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + else: + ############################################### + # **Updating the standard old PET checkpoints** + ############################################### + + # Adding the option for choosing the normalization type + if "normalization" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["normalization"] = "LayerNorm" + # Adding the option for choosing the activation function + if "activation" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["activation"] = "SiLU" + # Setting the node features dimension to be the same as d_pet if not specified + if "d_node" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["d_node"] = checkpoint[ + "model_data" + ]["model_hypers"]["d_pet"] + # Setting the default transformer type to PostLN if not specified + if "transformer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PostLN" + if "featurizer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "residual" + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + # Replacing the nn.Sequential MLP with a custom FeedForward module + if ".mlp.0" in k: + k = k.replace(".mlp.0", ".mlp.w_in") + if ".mlp.3" in k: + k = k.replace(".mlp.3", ".mlp.w_out") + # Moving the node embedder to a top-level PET model attribute + if "embedding." in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".node_embedder." in k: + key_content = k.split(".") + k = ".".join( + ["node_embedders", key_content[1], key_content[-1]] + ) + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + new_state_dict[k] = v + checkpoint[key] = new_state_dict + + +def model_update_v8_v9(checkpoint: dict) -> None: + """ + Update a v8 checkpoint to v9. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "finetune_config" in state_dict: + if "inherit_heads" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["inherit_heads"] = {} + if "method" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["method"] = "full" + + +def model_update_v9_v10(checkpoint: dict) -> None: + """ + Update a v9 checkpoint to v10. + + :param checkpoint: The checkpoint to update. 
+ """ + # Adding the num_neighbors_adaptive hyperparameter if not present + if "num_neighbors_adaptive" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["num_neighbors_adaptive"] = None + if "cutoff_function" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["cutoff_function"] = "Cosine" + + +def model_update_v10_v11(checkpoint: dict) -> None: + """ + Update a v10 checkpoint to v11. + + :param checkpoint: The checkpoint to update. + """ + # Adding the attention_temperature hyperparameter if not present + if "attention_temperature" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["attention_temperature"] = 1.0 + + +########################### +# TRAINER ################# +########################### + + +def trainer_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 Trainer checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["scheduler_factor"] = checkpoint["train_hypers"].get( + "scheduler_factor", 0.5 + ) + + +def trainer_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 Trainer checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + +def trainer_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 Trainer checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + new_loss_hypers[target_name] = { + "type": old_loss_hypers["type"], + "weight": old_loss_hypers["weights"].get(target_name, 1.0), + "reduction": old_loss_hypers["reduction"], + "sliding_factor": old_loss_hypers.get("sliding_factor", None), + } + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 Trainer checkpoint to v5. + + :param checkpoint: The checkpoint to update. + """ + raise ValueError( + "In order to use this checkpoint, you need metatrain 2025.10 or earlier. " + "You can install it with `pip install metatrain==2025.10`." + ) + + +def trainer_update_v5_v6(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 5 to version 6. + + :param checkpoint: The checkpoint to update. + """ + # num_workers=0 means that the main process will do the data loading, which is + # equivalent to not setting it (this was the behavior before v6) + checkpoint["train_hypers"]["num_workers"] = 0 + + +def trainer_update_v6_v7(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["fixed_scaling_weights"] = {} + + +def trainer_update_v7_v8(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. 
+ """ + # remove all entries in the loss `sliding_factor` + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + # retain everything except sliding_factor for each target + new_loss_hypers[target_name] = { + k: v + for k, v in old_loss_hypers[target_name].items() + if k != "sliding_factor" + } + + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v8_v9(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. + """ + # Adding the empty finetune config if not present + if "finetune" not in checkpoint["train_hypers"]: + checkpoint["train_hypers"]["finetune"] = { + "read_from": None, + "method": "full", + "config": {}, + "inherit_heads": {}, + } + + +def trainer_update_v9_v10(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 9 to version 10. + + :param checkpoint: The checkpoint to update. + """ + # Ensuring that the finetune read_from is None if not specified + checkpoint["train_hypers"]["remove_composition_contribution"] = True + + +def trainer_update_v10_v11(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 10 to version 11. + + :param checkpoint: The checkpoint to update. + """ + # - Remove the ``remove_composition_contribution`` hyper. + # - Rename ``fixed_composition_weights`` to ``atomic_baseline``. + # - If ``remove_composition_contribution`` is False, set all atomic baselines + # to 0.0 for all targets. + use_atomic_baseline = checkpoint["train_hypers"].pop( + "remove_composition_contribution" + ) + atomic_baseline = checkpoint["train_hypers"].pop("fixed_composition_weights") + + if not use_atomic_baseline: + # Just set + dataset_info = checkpoint["model_data"]["dataset_info"] + atomic_baseline = {target_name: 0.0 for target_name in dataset_info.targets} + + checkpoint["train_hypers"]["atomic_baseline"] = atomic_baseline + + +def trainer_update_v11_v12(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 11 to version 12. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["batch_atom_bounds"] = [None, None] diff --git a/src/metatrain/pet_6/documentation.py b/src/metatrain/pet_6/documentation.py new file mode 100644 index 0000000000..8c58a78eb7 --- /dev/null +++ b/src/metatrain/pet_6/documentation.py @@ -0,0 +1,251 @@ +""" +PET +=== + +PET is a cleaner, more user-friendly reimplementation of the original +PET model :footcite:p:`pozdnyakov_smooth_2023`. It is designed for better +modularity and maintainability, while preseving compatibility with the original +PET implementation in ``metatrain``. It also adds new features like long-range +features, better fine-tuning implementation, a possibility to train on +arbitrarty targets, and a faster inference due to the ``fast attention``. + +{{SECTION_INSTALLATION}} + +{{SECTION_DEFAULT_HYPERS}} + +Tuning hyperparameters +---------------------- + +The default hyperparameters above will work well in most cases, but they +may not be optimal for your specific dataset. There is good number of +parameters to tune, both for the +:ref:`model ` and the +:ref:`trainer `. Since seeing them +for the first time might be overwhelming, here we provide a **list of the +parameters that are in general the most important** (in decreasing order +of importance): + +.. container:: mtt-hypers-remove-classname + + .. 
autoattribute:: {{model_hypers_path}}.cutoff + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_neighbors_adaptive + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.learning_rate + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.batch_size + :no-index: + + .. autoattribute:: {{model_hypers_path}}.d_pet + :no-index: + + .. autoattribute:: {{model_hypers_path}}.d_node + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_gnn_layers + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_attention_layers + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.loss + :no-index: + + .. autoattribute:: {{model_hypers_path}}.long_range + :no-index: +""" + +from typing import Literal, Optional + +from typing_extensions import TypedDict + +from metatrain.pet_6.modules.finetuning import FinetuneHypers, NoFinetuneHypers +from metatrain.utils.additive import FixedCompositionWeights +from metatrain.utils.hypers import init_with_defaults +from metatrain.utils.long_range import LongRangeHypers +from metatrain.utils.loss import LossSpecification +from metatrain.utils.scaler import FixedScalerWeights + + +class ModelHypers(TypedDict): + """Hyperparameters for the PET model.""" + + cutoff: float = 4.5 + """Cutoff radius for neighbor search. + + This should be set to a value after which most of the interactions + between atoms are expected to be negligible. A lower cutoff will lead + to faster models. + """ + num_neighbors_adaptive: Optional[int] = None + """Target number of neighbors for the adaptive cutoff scheme. + + This parameter activates the adaptive cutoff functionality. + Each atomic environment has a different cutoff, chosen + such that the number of neighbors is approximately equal to this + value. This can be useful to obtain a more uniform number of neighbors + per atom, especially in sparse systems. Setting it to None disables + this feature and uses all neighbors within the fixed cutoff radius. + """ + cutoff_function: Literal["Cosine", "Bump"] = "Bump" + """Type of the smoothing function at the cutoff""" + cutoff_width: float = 0.5 + """Width of the smoothing function at the cutoff""" + d_pet: int = 128 + """Dimension of the edge features. + + This hyperparameter controls the width of the neural network. In general, + increasing it might lead to better accuracy, especially on larger datasets, at the + cost of increased training and evaluation time. + """ + d_head: int = 128 + """Dimension of the attention heads.""" + d_node: int = 256 + """Dimension of the node features. + + Increasing this hyperparameter might lead to better accuracy, + with a relatively small increase in inference time. + """ + d_feedforward: int = 256 + """Dimension of the feedforward network in the attention layer.""" + num_heads: int = 8 + """Attention heads per attention layer.""" + num_attention_layers: int = 2 + """The number of attention layers in each layer of the graph + neural network. Depending on the dataset, increasing this hyperparameter might + lead to better accuracy, at the cost of increased training and evaluation time. + """ + num_gnn_layers: int = 2 + """The number of graph neural network layers. + + In general, decreasing this hyperparameter to 1 will lead to much faster models, + at the expense of accuracy. Increasing it may or may not lead to better accuracy, + depending on the dataset, at the cost of increased training and evaluation time.
+ """ + normalization: Literal["RMSNorm", "LayerNorm"] = "RMSNorm" + """Layer normalization type.""" + activation: Literal["SiLU", "SwiGLU"] = "SwiGLU" + """Activation function.""" + attention_temperature: float = 1.0 + """The temperature scaling factor for attention scores.""" + transformer_type: Literal["PreLN", "PostLN"] = "PreLN" + """The order in which the layer normalization and attention + are applied in a transformer block. Available options are ``PreLN`` + (normalization before attention) and ``PostLN`` (normalization after attention).""" + featurizer_type: Literal["residual", "feedforward"] = "feedforward" + """Implementation of the featurizer of the model to use. Available + options are ``residual`` (the original featurizer from the PET paper, that uses + residual connections at each GNN layer for readout) and ``feedforward`` (a modern + version that uses the last representation after all GNN iterations for readout). + Additionally, the feedforward version uses bidirectional features flow during the + message passing iterations, that favors features flowing from atom ``i`` to atom + ``j`` to be not equal to the features flowing from atom ``j`` to atom ``i``.""" + zbl: bool = False + """Use ZBL potential for short-range repulsion""" + long_range: LongRangeHypers = init_with_defaults(LongRangeHypers) + """Long-range Coulomb interactions parameters.""" + + +class TrainerHypers(TypedDict): + """Hyperparameters for training PET models.""" + + distributed: bool = False + """Whether to use distributed training""" + distributed_port: int = 39591 + """Port for distributed communication among processes""" + batch_size: int = 16 + """The number of samples to use in each batch of training. This + hyperparameter controls the tradeoff between training speed and memory usage. In + general, larger batch sizes will lead to faster training, but might require more + memory.""" + num_epochs: int = 1000 + """Number of epochs.""" + warmup_fraction: float = 0.01 + """Fraction of training steps used for learning rate warmup.""" + learning_rate: float = 1e-4 + """Learning rate.""" + weight_decay: Optional[float] = None + + log_interval: int = 1 + """Interval to log metrics.""" + checkpoint_interval: int = 100 + """Interval to save checkpoints.""" + atomic_baseline: FixedCompositionWeights = {} + """The baselines for each target. + + By default, ``metatrain`` will fit a linear model (:class:`CompositionModel + `) to compute the + least squares baseline for each atomic species for each target. + + However, this hyperparameter allows you to provide your own baselines. + The value of the hyperparameter should be a dictionary where the keys are the + target names, and the values are either (1) a single baseline to be used for + all atomic types, or (2) a dictionary mapping atomic types to their baselines. + For example: + + - ``atomic_baseline: {"energy": {1: -0.5, 6: -10.0}}`` will fix the energy + baseline for hydrogen (Z=1) to -0.5 and for carbon (Z=6) to -10.0, while + fitting the baselines for the energy of all other atomic types, as well + as fitting the baselines for all other targets. + - ``atomic_baseline: {"energy": -5.0}`` will fix the energy baseline for + all atomic types to -5.0. + - ``atomic_baseline: {"mtt:dos": 0.0}`` sets the baseline for the "mtt:dos" + target to 0.0, effectively disabling the atomic baseline for that target. 
+ + This atomic baseline is substracted from the targets during training, which + avoids the main model needing to learn atomic contributions, and likely makes + training easier. When the model is used in evaluation mode, the atomic baseline + is added on top of the model predictions automatically. + + .. note:: + + This atomic baseline is a per-atom contribution. Therefore, if the property + you are predicting is a sum over all atoms (e.g., total energy), the + contribution of the atomic baseline to the total property will be the + atomic baseline multiplied by the number of atoms of that type in the + structure. + """ + scale_targets: bool = True + """Normalize targets to unit std during training.""" + fixed_scaling_weights: FixedScalerWeights = {} + """Weights for target scaling. + + This is passed to the ``fixed_weights`` argument of + :meth:`Scaler.train_model `, + see its documentation to understand exactly what to pass here. + """ + per_structure_targets: list[str] = [] + """Targets to calculate per-structure losses.""" + num_workers: Optional[int] = None + """Number of workers for data loading. If not provided, it is set + automatically.""" + log_mae: bool = True + """Log MAE alongside RMSE""" + log_separate_blocks: bool = False + """Log per-block error.""" + best_model_metric: Literal["rmse_prod", "mae_prod", "loss"] = "mae_prod" + """Metric used to select best checkpoint (e.g., ``rmse_prod``)""" + grad_clip_norm: float = 1.0 + """Maximum gradient norm value.""" + loss: str | dict[str, LossSpecification | str] = "mse" + """This section describes the loss function to be used. See the + :ref:`loss-functions` for more details.""" + batch_atom_bounds: list[Optional[int]] = [None, None] + """Bounds for the number of atoms per batch as [min, max]. Batches with atom + counts outside these bounds will be skipped during training. Use ``None`` for + either value to disable that bound. This is useful for preventing out-of-memory + errors and ensuring consistent computational load. Default: ``[None, None]``.""" + + finetune: NoFinetuneHypers | FinetuneHypers = { + "read_from": None, + "method": "full", + "config": {}, + "inherit_heads": {}, + } + """Parameters for fine-tuning trained PET models. + + See :ref:`label_fine_tuning_concept` for more details. + """ diff --git a/src/metatrain/pet_6/model.py b/src/metatrain/pet_6/model.py new file mode 100644 index 0000000000..b787301ea1 --- /dev/null +++ b/src/metatrain/pet_6/model.py @@ -0,0 +1,1646 @@ +import logging +import typing +import warnings +from math import prod +from typing import Any, Dict, List, Literal, Optional, Tuple + +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.operations._add import _add_block_block +from metatomic.torch import ( + AtomisticModel, + ModelCapabilities, + ModelMetadata, + ModelOutput, + NeighborListOptions, + System, +) +from torch.utils.hooks import RemovableHandle + +from metatrain.utils.abc import ModelInterface +from metatrain.utils.additive import ZBL, CompositionModel +from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.dtype import dtype_to_str +from metatrain.utils.long_range import DummyLongRangeFeaturizer, LongRangeFeaturizer +from metatrain.utils.metadata import merge_metadata +from metatrain.utils.scaler import Scaler +from metatrain.utils.sum_over_atoms import sum_over_atoms + +from . 
import checkpoints +from .documentation import ModelHypers +from .modules.finetuning import apply_finetuning_strategy +from .modules.structures import get_pair_sample_labels, systems_to_batch +from .modules.transformer import CartesianTransformer + + +AVAILABLE_FEATURIZERS = typing.get_args(ModelHypers.__annotations__["featurizer_type"]) + + +class PET(ModelInterface[ModelHypers]): + """ + Metatrain-native implementation of the PET architecture. + + Originally proposed in work (https://arxiv.org/abs/2305.19302v3), + and published in the `pet` package (https://github.com/spozdn/pet). + + :param hypers: Hyperparameters for the PET model. See the documentation for details. + :param dataset_info: Information about the dataset, including atomic types and + targets. + """ + + __checkpoint_version__ = 11 + __supported_devices__ = ["cuda", "cpu"] + __supported_dtypes__ = [torch.float32, torch.float64] + __default_metadata__ = ModelMetadata( + references={"architecture": ["https://arxiv.org/abs/2305.19302v3"]} + ) + component_labels: Dict[str, List[List[Labels]]] + NUM_FEATURE_TYPES: int = 2 # node + edge features + + def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None: + super().__init__(hypers, dataset_info, self.__default_metadata__) + + # Cache frequently accessed hyperparameters + self.cutoff = float(self.hypers["cutoff"]) + self.cutoff_function = self.hypers["cutoff_function"] + self.cutoff_width = float(self.hypers["cutoff_width"]) + self.num_neighbors_adaptive = ( + float(self.hypers["num_neighbors_adaptive"]) + if self.hypers["num_neighbors_adaptive"] is not None + else None + ) + self.d_pet = self.hypers["d_pet"] + self.d_node = self.hypers["d_node"] + self.d_head = self.hypers["d_head"] + self.d_feedforward = self.hypers["d_feedforward"] + self.num_heads = self.hypers["num_heads"] + self.num_gnn_layers = self.hypers["num_gnn_layers"] + self.num_attention_layers = self.hypers["num_attention_layers"] + self.normalization = self.hypers["normalization"] + self.activation = self.hypers["activation"] + self.attention_temperature = self.hypers["attention_temperature"] + self.transformer_type = self.hypers["transformer_type"] + self.featurizer_type = self.hypers["featurizer_type"] + + self.atomic_types = dataset_info.atomic_types + self.requested_nl = NeighborListOptions( + cutoff=self.cutoff, + full_list=True, + strict=True, + ) + num_atomic_species = len(self.atomic_types) + self.gnn_layers = torch.nn.ModuleList( + [ + CartesianTransformer( + self.cutoff, + self.cutoff_width, + self.d_pet, + self.num_heads, + self.d_node, + self.d_feedforward, + self.num_attention_layers, + self.normalization, + self.activation, + self.attention_temperature, + self.transformer_type, + num_atomic_species, + layer_index == 0, # is first layer + ) + for layer_index in range(self.num_gnn_layers) + ] + ) + if self.featurizer_type not in AVAILABLE_FEATURIZERS: + raise ValueError( + f"Unknown featurizer type: {self.featurizer_type}. 
" + f"Available options are: {AVAILABLE_FEATURIZERS}" + ) + if self.featurizer_type == "feedforward": + self.num_readout_layers = 1 + self.combination_norms = torch.nn.ModuleList( + [torch.nn.LayerNorm(2 * self.d_pet) for _ in range(self.num_gnn_layers)] + ) + self.combination_mlps = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(2 * self.d_pet, 2 * self.d_pet), + torch.nn.SiLU(), + torch.nn.Linear(2 * self.d_pet, self.d_pet), + ) + for _ in range(self.num_gnn_layers) + ] + ) + else: + self.num_readout_layers = self.num_gnn_layers + self.combination_norms = torch.nn.ModuleList() + self.combination_mlps = torch.nn.ModuleList() + + self.node_embedders = torch.nn.ModuleList( + [ + torch.nn.Embedding(num_atomic_species, self.d_node) + for _ in range(self.num_readout_layers) + ] + ) + self.edge_embedder = torch.nn.Embedding(num_atomic_species, self.d_pet) + + self.node_heads = torch.nn.ModuleDict() + self.edge_heads = torch.nn.ModuleDict() + self.node_last_layers = torch.nn.ModuleDict() + self.edge_last_layers = torch.nn.ModuleDict() + self.last_layer_feature_size = ( + self.num_readout_layers * self.d_head * self.NUM_FEATURE_TYPES + ) # for LLPR + + # the model is always capable of outputting the internal features + self.outputs = { + "features": ModelOutput(per_atom=True, description="internal features") + } + + self.output_shapes: Dict[str, Dict[str, List[int]]] = {} + self.key_labels: Dict[str, Labels] = {} + self.property_labels: Dict[str, List[Labels]] = {} + self.component_labels: Dict[str, List[List[Labels]]] = {} + self.target_names: List[str] = [] + self.last_layer_parameter_names: Dict[str, List[str]] = {} # for LLPR + for target_name, target_info in dataset_info.targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target_info) + + self.register_buffer( + "species_to_species_index", + torch.full((max(self.atomic_types) + 1,), -1), + ) + for i, species in enumerate(self.atomic_types): + self.species_to_species_index[species] = i + + # long-range module + if self.hypers["long_range"]["enable"]: + self.long_range = True + if not self.hypers["long_range"]["use_ewald"]: + warnings.warn( + "Training PET with the LongRangeFeaturizer initialized " + "with `use_ewald=False` causes instabilities during training. " + "The `use_ewald` variable will be force-switched to `True`. 
" + "during training.", + UserWarning, + stacklevel=2, + ) + self.long_range_featurizer = LongRangeFeaturizer( + hypers=self.hypers["long_range"], + feature_dim=self.d_node, + neighbor_list_options=self.requested_nl, + ) + else: + self.long_range = False + self.long_range_featurizer = DummyLongRangeFeaturizer() # for torchscript + + # additive models: these are handled by the trainer at training + # time, and they are added to the output at evaluation time + composition_model = CompositionModel( + hypers={}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + additive_models = [composition_model] + + # Adds the ZBL repulsion model if requested + if self.hypers["zbl"]: + additive_models.append( + ZBL( + {}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if ZBL.is_valid_target(target_name, target_info) + }, + ), + ) + ) + self.additive_models = torch.nn.ModuleList(additive_models) + + # scaler: this is also handled by the trainer at training time + self.scaler = Scaler(hypers={}, dataset_info=dataset_info) + + self.single_label = Labels.single() + + self.finetune_config: Dict[str, Any] = {} + + def supported_outputs(self) -> Dict[str, ModelOutput]: + return self.outputs + + def restart(self, dataset_info: DatasetInfo) -> "PET": + # merge old and new dataset info + merged_info = self.dataset_info.union(dataset_info) + new_atomic_types = [ + at for at in merged_info.atomic_types if at not in self.atomic_types + ] + new_targets = { + key: value + for key, value in merged_info.targets.items() + if key not in self.dataset_info.targets + } + self.has_new_targets = len(new_targets) > 0 + + if len(new_atomic_types) > 0: + raise ValueError( + f"New atomic types found in the dataset: {new_atomic_types}. " + "The PET model does not support adding new atomic types." + ) + + # register new outputs as new last layers + for target_name, target in new_targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target) + + self.dataset_info = merged_info + + # restart the composition and scaler models + self.additive_models[0] = self.additive_models[0].restart( + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + self.scaler = self.scaler.restart(dataset_info) + + return self + + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return [self.requested_nl] + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + """ + Forward pass of the PET model. + + The forward pass processes atomic systems through multiple stages to produce + predictions for the requested outputs. The computation follows a graph neural + network architecture with attention-based message passing. 
+ + **Stage 0: Input Preparation** + + The input systems are first converted into a batched representation containing: + + - `element_indices_nodes` [n_atoms]: Atomic species of the central atoms + - `element_indices_neighbors` [n_atoms, max_num_neighbors]: Atomic species of + neighboring atoms + - `edge_vectors` [n_atoms, max_num_neighbors, 3]: Cartesian edge vectors + between central atoms and their neighbors + - `padding_mask` [n_atoms, max_num_neighbors]: Mask indicating real vs padded + neighbors + - `reverse_neighbor_index` [n_atoms * max_num_neighbors]: Index of the ji edge + for each ij edge, once the edges are flattened into an array whose first + dimension is n_atoms * max_num_neighbors + - `system_indices` [n_atoms]: System index for each central atom + - `sample_labels` [n_atoms, 2]: Metatensor Labels containing indices of each + atom in each system + + **Stage 1: Feature Computation via GNN Layers** + + Node and edge representations are computed by iterating through the GNN layers + following one of two featurization strategies: + + - **Feedforward featurization**: Features are propagated through all + `num_gnn_layers` GNN layers sequentially, using only the final layer outputs + for readout. At each layer, forward and reversed edge messages are combined + using combination MLPs to enable bidirectional information flow. + + - **Residual featurization**: Intermediate node and edge features from each + GNN layer are saved and used during readout. Edge messages between layers are + averaged to maintain information from all hops. + + During this stage, the model: + + - Embeds atomic species into learned node and edge representations + - Applies Cartesian transformer layers to update features via attention + - Uses reversed neighbor lists to enable bidirectional message passing, where + the new input message from atom `j` to atom `i` in GNN layer N+1 is the + reversed message from atom `i` to atom `j` in GNN layer N + - Applies cutoff functions to weight interactions by distance + + If the long-range module is enabled, electrostatic features computed via Ewald + summation (during training) or Particle-Particle Particle Mesh Ewald (P3M) + (during evaluation) are added to the node features from each GNN layer. + + **Stage 2: Intermediate Feature Output (Optional)** + + If "features" is requested in the outputs, node and edge features from all + layers are concatenated to produce intermediate representations. Edge features + are summed over neighbors with cutoff weighting to obtain per-node + contributions. This output can be used for transfer learning or analysis. + + **Stage 3: Last Layer Feature Computation** + + For each requested output, output-specific heads (shallow MLPs with two linear + layers and SiLU activations) are applied to both node and edge features from + each GNN layer. This produces last layer features that are specialized for each + prediction target. These features can be optionally returned as auxiliary + outputs (e.g., "mtt::aux::energy_last_layer_features") for analysis or + transfer learning. 
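+
+        For example (a minimal sketch, assuming an ``energy`` target is present),
+        these auxiliary features can be requested alongside the prediction itself:
+
+        .. code-block:: python
+
+            outputs = {
+                "energy": ModelOutput(per_atom=False),
+                "mtt::aux::energy_last_layer_features": ModelOutput(per_atom=True),
+            }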
+ + **Stage 4: Atomic Predictions** + + Final linear layers are applied to the last layer features to produce per-atom + predictions for each requested output: + + - Node and edge last layer features are processed through separate linear + layers for each output block + - Contributions from all GNN layers are summed + - Edge contributions are summed over neighbors with cutoff weighting + - For rank-2 Cartesian tensors (e.g., stress), predictions are symmetrized and + normalized by cell volume + - Multiple tensor blocks per output are handled independently + + **Post-processing (Evaluation Only)** + + During evaluation (not training), the following transformations are applied: + + 1. **Scaling**: Predictions are scaled using learned or configured scale + factors + 2. **Additive contributions**: Composition model and optional ZBL repulsion + contributions are added to the predictions + + :param systems: List of `metatomic.torch.System` objects to process. Each + system should contain atomic positions, species, and cell information, with + neighbor lists computed according to `requested_neighbor_lists()`. + :param outputs: Dictionary of requested outputs in the format + {output_name: ModelOutput(...)}. The model supports: + + - Target properties (energy, forces, stress, etc.) + - "features": intermediate representations from Stage 2 + - Auxiliary last layer features (e.g., + "mtt::aux::energy_last_layer_features") + + :param selected_atoms: Optional `metatensor.torch.Labels` object specifying a + subset of atoms for which to compute outputs. If `None`, all atoms are + included. This is useful for computing properties for specific atomic + environments. + :return: Dictionary of `metatensor.torch.TensorMap` objects containing the + requested outputs. Each TensorMap contains per-atom or per-structure + predictions (depending on the ModelOutput configuration) with appropriate + metatensor metadata (samples, components, properties). 
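+
+        Example (a minimal sketch, assuming ``model`` is a trained ``PET`` instance
+        with an ``energy`` target and that each system already carries the neighbor
+        lists returned by :meth:`requested_neighbor_lists`):
+
+        .. code-block:: python
+
+            from metatomic.torch import ModelOutput
+
+            outputs = {"energy": ModelOutput(per_atom=False)}
+            predictions = model(systems, outputs)
+            energies = predictions["energy"].block().values  # one row per system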
+ """ + device = systems[0].device + return_dict: Dict[str, TensorMap] = {} + nl_options = self.requested_neighbor_lists()[0] + + if self.single_label.values.device != device: + self._move_labels_to_device(device) + + # **Stage 0: Input Preparation** + ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + padding_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + sample_labels, + centers, + nef_to_edges_neighbor, + ) = systems_to_batch( + systems, + nl_options, + self.atomic_types, + self.species_to_species_index, + self.cutoff_function, + self.cutoff_width, + self.num_neighbors_adaptive, + selected_atoms, + ) + + pair_sample_labels = get_pair_sample_labels( + systems, sample_labels, nl_options, device + ) + + # Optional diagnostic token capture: register temporary module hooks + diagnostic_handles = torch.jit.annotate(List[Any], []) + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + diagnostic_handles = self._prepare_diagnostic_handles( + outputs, + return_dict, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + + # the scaled_dot_product_attention function from torch cannot do + # double backward, so we will use manual attention if needed + use_manual_attention = edge_vectors.requires_grad and self.training + + with torch.profiler.record_function("PET::_calculate_features"): + # **Stage 1: Feature Computation via GNN Layers** + featurizer_inputs: Dict[str, torch.Tensor] = dict( + element_indices_nodes=element_indices_nodes, + element_indices_neighbors=element_indices_neighbors, + edge_vectors=edge_vectors, + edge_distances=edge_distances, + reverse_neighbor_index=reverse_neighbor_index, + padding_mask=padding_mask, + cutoff_factors=cutoff_factors, + ) + for featurizer_input_name, tensor in featurizer_inputs.items(): + if "mtt::features::" + featurizer_input_name in outputs: + return_dict["mtt::features::" + featurizer_input_name] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + node_features_list, edge_features_list = self._calculate_features( + featurizer_inputs, + use_manual_attention=use_manual_attention, + ) + + # If the long-range module is activated, we add the long-range features + # on top of the node features + + if self.long_range: + long_range_features = self._calculate_long_range_features( + systems, node_features_list, edge_distances, padding_mask + ) + for i in range(self.num_readout_layers): + node_features_list[i] = ( + node_features_list[i] + long_range_features + ) * 0.5**0.5 + + # **Stage 2: Intermediate Feature Output (Optional)** + with torch.profiler.record_function("PET::_get_output_features"): + if "features" in outputs: + features_dict = self._get_output_features( + node_features_list, + edge_features_list, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + # Since return_dict.update(features_dict) is not Torch-Scriptable, + # we use a simple iteration over the features_dict items. 
+ for k, v in features_dict.items(): + return_dict[k] = v + + # **Stage 3: Last Layer Feature Computation** + with torch.profiler.record_function("PET::_calculate_last_layer_features"): + node_last_layer_features_dict, edge_last_layer_features_dict = ( + self._calculate_last_layer_features( + node_features_list, + edge_features_list, + ) + ) + last_layer_features_dict = self._get_output_last_layer_features( + node_last_layer_features_dict, + edge_last_layer_features_dict, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + + for k, v in last_layer_features_dict.items(): + return_dict[k] = v + + # **Stage 4: Atomic Predictions** + with torch.profiler.record_function("PET::_calculate_atomic_predictions"): + node_atomic_predictions_dict, edge_atomic_predictions_dict = ( + self._calculate_atomic_predictions( + node_last_layer_features_dict, + edge_last_layer_features_dict, + padding_mask, + cutoff_factors, + outputs, + ) + ) + atomic_predictions_dict = self._get_output_atomic_predictions( + systems, + node_atomic_predictions_dict, + edge_atomic_predictions_dict, + edge_vectors, + system_indices, + sample_labels, + outputs, + selected_atoms, + ) + + for k, v in atomic_predictions_dict.items(): + return_dict[k] = v + + # **Post-processing (Evaluation Only)** + + with torch.profiler.record_function("PET::post-processing"): + if not self.training: + # at evaluation, we also introduce the scaler and additive contributions + return_dict = self.scaler( + systems, return_dict, selected_atoms=selected_atoms + ) + for additive_model in self.additive_models: + outputs_for_additive_model: Dict[str, ModelOutput] = {} + for name, output in outputs.items(): + if name in additive_model.outputs: + outputs_for_additive_model[name] = output + additive_contributions = additive_model( + systems, + outputs_for_additive_model, + selected_atoms, + ) + for name in additive_contributions: + # TODO: uncomment this after metatensor.torch.add + # is updated to handle sparse sums + # return_dict[name] = metatensor.torch.add( + # return_dict[name], + # additive_contributions[name].to( + # device=return_dict[name].device, + # dtype=return_dict[name].dtype + # ), + # ) + # TODO: "manual" sparse sum: update to metatensor.torch.add + # after sparse sum is implemented in metatensor.operations + output_blocks: List[TensorBlock] = [] + for k, b in return_dict[name].items(): + if k in additive_contributions[name].keys: + output_blocks.append( + _add_block_block( + b, + additive_contributions[name] + .block(k) + .to(device=b.device, dtype=b.dtype), + ) + ) + else: + output_blocks.append(b) + return_dict[name] = TensorMap( + return_dict[name].keys, output_blocks + ) + + # remove any diagnostic hooks we registered and attach tokens + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + for h in diagnostic_handles: + h.remove() + + return return_dict + + def _create_diagnostic_feature_tensormap( + self, + tensor: torch.Tensor, + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> TensorMap: + assert tensor.shape[0] == sample_labels.values.shape[0], ( + "diagnostic feature tensor must be per-atom or per-pair like in shape." 
+ f" Got tensor.shape = {tensor.shape}, " + ) + + device = tensor.device + + outp = tensor.detach().clone() + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + if outp.shape[1] == 1: # node-like, shape (n_atoms, 1, d) + outp = outp.squeeze(1) + labels = sample_labels + + else: # edge-like, shape (n_atoms, num_neighbors, d) + outp = outp[centers, nef_to_edges_neighbor] + labels = pair_sample_labels + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + return TensorMap( + Labels(["_"], torch.tensor([[0]])).to(device=device), + [ + TensorBlock( + values=outp, + samples=labels.to(device=device), + components=[], + properties=Labels( + ["_"], + torch.arange(outp.shape[1]).reshape(-1, 1), + ).to(device=device), + ) + ], + ) + + @torch.jit.ignore + def _prepare_diagnostic_handles( + self, + outputs: Dict[str, ModelOutput], + return_dict: Dict[str, Any], + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> List[Any]: + """ + Prepare forward hooks to capture diagnostic tokens from internal modules. + + :param outputs: Dictionary of requested outputs. + :param return_dict: Dictionary to store captured tokens. + :param centers: Tensor mapping center atoms to their indices. + :param nef_to_edges_neighbor: Tensor mapping neighbor edges to their indices. + :param sample_labels: Labels for individual atoms. + :param pair_sample_labels: Labels for atom pairs. + + :return: List of removable handles for the registered hooks. + """ + diagnostic_handles = [] + + def _resolve_module(path: str) -> Any: + obj: Any = self + for part in path.split("."): + if part.isdigit(): + obj = obj[int(part)] + else: + if not hasattr(obj, part): + raise AttributeError( + f"Module path '{path}' not found at '{part}'" + ) + obj = getattr(obj, part) + return obj + + # Build list of possible module paths that can be captured for these model + # hypers. 
+ possible_capture_paths: List[str] = [] + + for i in range(self.num_readout_layers): + possible_capture_paths.append(f"node_embedders.{i}") + possible_capture_paths.append("edge_embedder") + + for i in range(self.num_gnn_layers): + # Total GNN layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}_node") + possible_capture_paths.append(f"gnn_layers.{i}_edge") + + # Finer-grained: embeddings and compressions + possible_capture_paths.append(f"gnn_layers.{i}.edge_embedder") + if i > 0: + possible_capture_paths.append(f"gnn_layers.{i}.neighbor_embedder") + possible_capture_paths.append(f"gnn_layers.{i}.compress") + + # transformer layers + for j in range(self.num_attention_layers): + # Transformer layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_node") + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_edge") + + # Finer-grained: attention and MLP submodules + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_contraction" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_mlp" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_mlp" + ) + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}.mlp") + + # TODO: this is subtle, depending on whether the featurizer is feedforward or + # residual. To fix later. + # possible_capture_paths.append(f"combination_norms") + # possible_capture_paths.append(f"combination_mlps") + + for path in possible_capture_paths: + if "mtt::features::" + path not in outputs: + continue + + if "_node" in path: + suffix = "_node" + path = path.replace(suffix, "") + elif "_edge" in path: + suffix = "_edge" + path = path.replace(suffix, "") + else: + suffix = "" + + module = _resolve_module(path) + + def make_hook(p: str, suffix: str) -> Any: + def _hook( + module: torch.nn.Module, inp: Any, outp: Any + ) -> None: + if isinstance(outp, tuple): + assert "_node" in suffix or "_edge" in suffix, ( + "When capturing from a module that returns multiple " + "outputs, the requested output must carry the suffix " + "'_node' or '_edge'." + ) + if suffix == "_node": + tensor = outp[0] + else: + assert suffix == "_edge" + tensor = outp[1] + + return_dict[f"mtt::features::{p}{suffix}"] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + else: + return_dict["mtt::features::" + p] = ( + self._create_diagnostic_feature_tensormap( + outp, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + + return _hook + + handle = module.register_forward_hook(make_hook(path, suffix)) + diagnostic_handles.append(handle) + + return diagnostic_handles + + def _calculate_features( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Calculate node and edge features using the selected featurization strategy. + Returns lists of feature tensors from GNN layers. 
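+        The strategy is selected by ``self.featurizer_type``: "feedforward" uses
+        ``_feedforward_featurization_impl``, any other value falls back to
+        ``_residual_featurization_impl``.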
+ + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors + - List of edge feature tensors + In the case of feedforward featurization, each list contains a single tensor + from the final GNN layer. In the case of residual featurization, each list + contains tensors from all GNN layers. + """ + if self.featurizer_type == "feedforward": + return self._feedforward_featurization_impl(inputs, use_manual_attention) + else: + return self._residual_featurization_impl(inputs, use_manual_attention) + + def _feedforward_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Feedforward featurization: iterates features through all GNN layers, + returning only the final layer outputs. Uses combination MLPs to mix + forward and reversed edge messages at each layer. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from the final GNN layer + - List of edge feature tensors from the final GNN layer + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + + input_node_embeddings = self.node_embedders[0](inputs["element_indices_nodes"]) + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for combination_norm, combination_mlp, gnn_layer in zip( + self.combination_norms, self.combination_mlps, self.gnn_layers, strict=True + ): + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + input_node_embeddings = output_node_embeddings + new_input_edge_embeddings = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + # input_messages = 0.5 * (output_edge_embeddings + new_input_messages) + concatenated = torch.cat( + [output_edge_embeddings, new_input_edge_embeddings], dim=-1 + ) + input_edge_embeddings = ( + input_edge_embeddings + + output_edge_embeddings + + combination_mlp(combination_norm(concatenated)) + ) + + node_features_list.append(input_node_embeddings) + edge_features_list.append(input_edge_embeddings) + return node_features_list, edge_features_list + + def _residual_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Residual featurization: saves intermediate features from each GNN layer + for use in readout. 
Averages forward and reversed edge messages between layers.
+
+        :param inputs: Dictionary containing input tensors required for feature
+            computation
+        :param use_manual_attention: Whether to use manual attention computation
+            (required for double backward when edge vectors require gradients)
+        :return: Tuple of two lists:
+            - List of node feature tensors from all GNN layers
+            - List of edge feature tensors from all GNN layers
+        """
+        node_features_list: List[torch.Tensor] = []
+        edge_features_list: List[torch.Tensor] = []
+        input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"])
+        for node_embedder, gnn_layer in zip(
+            self.node_embedders, self.gnn_layers, strict=True
+        ):
+            input_node_embeddings = node_embedder(inputs["element_indices_nodes"])
+            output_node_embeddings, output_edge_embeddings = gnn_layer(
+                input_node_embeddings,
+                input_edge_embeddings,
+                inputs["element_indices_neighbors"],
+                inputs["edge_vectors"],
+                inputs["padding_mask"],
+                inputs["edge_distances"],
+                inputs["cutoff_factors"],
+                use_manual_attention,
+            )
+            node_features_list.append(output_node_embeddings)
+            edge_features_list.append(output_edge_embeddings)
+
+            # The GNN contraction happens by reordering the messages,
+            # using a reversed neighbor list, so the new input message
+            # from atom `j` to atom `i` on the GNN layer N+1 is a
+            # reversed message from atom `i` to atom `j` on the GNN layer N.
+            # (Flatten, index, and reshape to the original shape)
+            new_input_messages = output_edge_embeddings.reshape(
+                output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1],
+                output_edge_embeddings.shape[2],
+            )[inputs["reverse_neighbor_index"]].reshape(
+                output_edge_embeddings.shape[0],
+                output_edge_embeddings.shape[1],
+                output_edge_embeddings.shape[2],
+            )
+            input_edge_embeddings = 0.5 * (input_edge_embeddings + new_input_messages)
+        return node_features_list, edge_features_list
+
+    def _calculate_long_range_features(
+        self,
+        systems: List[System],
+        node_features_list: List[torch.Tensor],
+        edge_distances: torch.Tensor,
+        padding_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Calculate long-range electrostatic features using Ewald summation.
+        Forces use_ewald=True during training for stability.
+
+        :param systems: List of `metatomic.torch.System` objects to process.
+        :param node_features_list: List of node feature tensors from each GNN layer.
+        :param edge_distances: Tensor of edge distances [n_atoms, max_num_neighbors].
+        :param padding_mask: Boolean mask indicating real vs padded neighbors
+            [n_atoms, max_num_neighbors].
+        :return: Tensor of long-range features [n_atoms, d_pet].
+        """
+        if self.training:
+            # Currently, the long-range implementation shows instabilities
+            # during training if P3MCalculator is used instead of the
+            # EwaldCalculator. We will use the EwaldCalculator for training.
+ self.long_range_featurizer.use_ewald = True + flattened_lengths = edge_distances[padding_mask] + short_range_features = ( + torch.stack(node_features_list).sum(dim=0) + * (1 / len(node_features_list)) ** 0.5 + ) + long_range_features = self.long_range_featurizer( + systems, short_range_features, flattened_lengths + ) + return long_range_features + + def _get_output_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Concatenate node and edge features from all layers into intermediate + feature representations. Edge features are summed with cutoff weighting. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping "features" to a TensorMap of intermediate + representations, either per-atom or summed over atoms. + """ + features_dict: Dict[str, TensorMap] = {} + node_features = torch.cat(node_features_list, dim=1) + edge_features = torch.cat(edge_features_list, dim=2) + edge_features = (edge_features * cutoff_factors[:, :, None]).sum(dim=1) + features = torch.cat([node_features, edge_features], dim=1) + + feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=features, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + features.shape[-1], device=features.device + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + feature_tmap = mts.slice( + feature_tmap, + axis="samples", + selection=selected_atoms, + ) + if requested_outputs["features"].per_atom: + features_dict["features"] = feature_tmap + else: + features_dict["features"] = sum_over_atoms(feature_tmap) + return features_dict + + def _calculate_last_layer_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + ) -> Tuple[Dict[str, List[torch.Tensor]], Dict[str, List[torch.Tensor]]]: + """ + Apply output-specific heads to node and edge features from each GNN layer. + Returns dictionaries mapping output names to lists of head-transformed features. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of node last layer features + - Dictionary mapping output names to lists of edge last layer features + """ + node_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + + # Calculating node last layer features + for output_name, node_heads in self.node_heads.items(): + if output_name not in node_last_layer_features_dict: + node_last_layer_features_dict[output_name] = [] + for i, node_head in enumerate(node_heads): + node_last_layer_features_dict[output_name].append( + node_head(node_features_list[i]) + ) + + # Calculating edge last layer features + for output_name, edge_heads in self.edge_heads.items(): + if output_name not in edge_last_layer_features_dict: + edge_last_layer_features_dict[output_name] = [] + for i, edge_head in enumerate(edge_heads): + edge_last_layer_features_dict[output_name].append( + edge_head(edge_features_list[i]) + ) + + return node_last_layer_features_dict, edge_last_layer_features_dict + + def _get_output_last_layer_features( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge last layer features for requested last layer + features output. Edge features are summed with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping requested last layer features output names + to TensorMaps of last layer features, either per-atom or summed over atoms. 
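+
+        For example (with a hypothetical target named "energy"), requesting
+        "mtt::aux::energy_last_layer_features" returns the concatenated node and
+        edge last-layer features of the "energy" (or "mtt::energy") target.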
+ """ + last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + last_layer_features_outputs: Dict[str, TensorMap] = {} + for output_name in node_last_layer_features_dict.keys(): + if not should_compute_last_layer_features(output_name, requested_outputs): + continue + if output_name not in last_layer_features_dict: + last_layer_features_dict[output_name] = [] + for i in range(len(node_last_layer_features_dict[output_name])): + node_last_layer_features = node_last_layer_features_dict[output_name][i] + edge_last_layer_features = edge_last_layer_features_dict[output_name][i] + edge_last_layer_features = ( + edge_last_layer_features * cutoff_factors[:, :, None] + ).sum(dim=1) + last_layer_features_dict[output_name].append(node_last_layer_features) + last_layer_features_dict[output_name].append(edge_last_layer_features) + + for output_name in requested_outputs: + if not ( + output_name.startswith("mtt::aux::") + and output_name.endswith("_last_layer_features") + ): + continue + base_name = output_name.replace("mtt::aux::", "").replace( + "_last_layer_features", "" + ) + # the corresponding output could be base_name or mtt::base_name + if f"mtt::{base_name}" in last_layer_features_dict: + base_name = f"mtt::{base_name}" + last_layer_features_values = torch.cat( + last_layer_features_dict[base_name], dim=1 + ) + last_layer_feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=last_layer_features_values, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + last_layer_features_values.shape[-1], + device=last_layer_features_values.device, + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + last_layer_feature_tmap = mts.slice( + last_layer_feature_tmap, + axis="samples", + selection=selected_atoms, + ) + last_layer_features_options = requested_outputs[output_name] + if last_layer_features_options.per_atom: + last_layer_features_outputs[output_name] = last_layer_feature_tmap + else: + last_layer_features_outputs[output_name] = sum_over_atoms( + last_layer_feature_tmap + ) + return last_layer_features_outputs + + def _calculate_atomic_predictions( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + padding_mask: torch.Tensor, + cutoff_factors: torch.Tensor, + outputs: Dict[str, ModelOutput], + ) -> Tuple[ + Dict[str, List[List[torch.Tensor]]], Dict[str, List[List[torch.Tensor]]] + ]: + """ + Apply final linear layers to last layer features to produce + per-atom predictions. Handles multiple blocks per output and sums + edge contributions with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param outputs: Dictionary of requested outputs. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of lists of node atomic + prediction tensors (one list per GNN layer, one tensor per block) + - Dictionary mapping output names to lists of lists of edge atomic + prediction tensors (one list per GNN layer, one tensor per block) + """ + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + + # Computing node atomic predictions. Since we have last layer features + # for each GNN layer, and each last layer can have multiple blocks, + # we apply each last layer block to each of the last layer features. + + for output_name, node_last_layers in self.node_last_layers.items(): + if output_name in outputs: + node_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, node_last_layer in enumerate(node_last_layers): + node_last_layer_features = node_last_layer_features_dict[ + output_name + ][i] + node_atomic_predictions_by_block: List[torch.Tensor] = [] + for node_last_layer_by_block in node_last_layer.values(): + node_atomic_predictions_by_block.append( + node_last_layer_by_block(node_last_layer_features) + ) + node_atomic_predictions_dict[output_name].append( + node_atomic_predictions_by_block + ) + + # Computing edge atomic predictions. Following the same logic as above, + # we (1) iterate over the last layer features and last layer blocks, and (2) + # sum the edge features with cutoff factors to get their per-node contribution. + + for output_name, edge_last_layers in self.edge_last_layers.items(): + if output_name in outputs: + edge_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, edge_last_layer in enumerate(edge_last_layers): + edge_last_layer_features = edge_last_layer_features_dict[ + output_name + ][i] + edge_atomic_predictions_by_block: List[torch.Tensor] = [] + for edge_last_layer_by_block in edge_last_layer.values(): + edge_atomic_predictions = edge_last_layer_by_block( + edge_last_layer_features + ) + expanded_padding_mask = padding_mask[..., None].repeat( + 1, 1, edge_atomic_predictions.shape[2] + ) + edge_atomic_predictions = torch.where( + ~expanded_padding_mask, 0.0, edge_atomic_predictions + ) + edge_atomic_predictions_by_block.append( + (edge_atomic_predictions * cutoff_factors[:, :, None]).sum( + dim=1 + ) + ) + edge_atomic_predictions_dict[output_name].append( + edge_atomic_predictions_by_block + ) + + return node_atomic_predictions_dict, edge_atomic_predictions_dict + + def _get_output_atomic_predictions( + self, + systems: List[System], + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_vectors: torch.Tensor, + system_indices: torch.Tensor, + sample_labels: Labels, + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge atomic predictions into final TensorMaps. + Handles rank-2 Cartesian tensors by symmetrizing them. + Returns per-atom or per-structure predictions based on output configuration. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of node atomic prediction tensors (one list per GNN layer, + one tensor per block). 
+ :param edge_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of edge atomic prediction tensors (one list per GNN layer, + one tensor per block). + :param edge_vectors: Tensor of edge vectors [n_atoms, max_num_neighbors, 3]. + :param system_indices: Tensor mapping each atom to its system index + [n_atoms]. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param outputs: Dictionary of requested outputs. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :return: Dictionary mapping requested output names to TensorMaps of + predictions, either per-atom or summed over atoms. + """ + atomic_predictions_tmap_dict: Dict[str, TensorMap] = {} + for output_name in self.target_names: + if output_name in outputs: + atomic_predictions_by_block = { + key: torch.zeros( + 1, dtype=edge_vectors.dtype, device=edge_vectors.device + ) + for key in self.output_shapes[output_name].keys() + } + + node_atomic_predictions_by_block = node_atomic_predictions_dict[ + output_name + ] + edge_atomic_predictions_by_block = edge_atomic_predictions_dict[ + output_name + ] + for i in range(len(node_atomic_predictions_by_block)): + node_atomic_prediction_block = node_atomic_predictions_by_block[i] + edge_atomic_prediction_block = edge_atomic_predictions_by_block[i] + for j, key in enumerate(atomic_predictions_by_block): + node_atomic_predictions = node_atomic_prediction_block[j] + edge_atomic_predictions = edge_atomic_prediction_block[j] + atomic_predictions_by_block[key] = atomic_predictions_by_block[ + key + ] + (node_atomic_predictions + edge_atomic_predictions) + + if output_name == "non_conservative_stress": # TODO: variants + block_key = list(atomic_predictions_by_block.keys())[0] + output_shapes_values = list( + self.output_shapes[output_name].values() + ) + num_properties = output_shapes_values[0][-1] + symmetrized = process_non_conservative_stress( + atomic_predictions_by_block[block_key], + systems, + system_indices, + num_properties, + ) + atomic_predictions_by_block[block_key] = symmetrized + + blocks = [ + TensorBlock( + values=atomic_predictions_by_block[key].reshape([-1] + shape), + samples=sample_labels, + components=components, + properties=properties, + ) + for key, shape, components, properties in zip( + self.output_shapes[output_name].keys(), + self.output_shapes[output_name].values(), + self.component_labels[output_name], + self.property_labels[output_name], + strict=True, + ) + ] + atomic_predictions_tmap_dict[output_name] = TensorMap( + keys=self.key_labels[output_name], + blocks=blocks, + ) + # If selected atoms request is provided, we slice the atomic predictions + # tensor maps to get the predictions for the selected atoms only. + + if selected_atoms is not None: + for output_name, tmap in atomic_predictions_tmap_dict.items(): + atomic_predictions_tmap_dict[output_name] = mts.slice( + tmap, axis="samples", selection=selected_atoms + ) + + # If per-atom predictions are requested, we return the atomic predictions + # tensor maps. Otherwise, we sum the atomic predictions over the atoms + # to get the final per-structure predictions for each requested output. 
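+        # `sum_over_atoms` is assumed here to reduce each per-atom block to a
+        # single sample per structure by summing its values over the atoms.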
+ + for output_name, atomic_property in atomic_predictions_tmap_dict.items(): + if outputs[output_name].per_atom: + atomic_predictions_tmap_dict[output_name] = atomic_property + else: + atomic_predictions_tmap_dict[output_name] = sum_over_atoms( + atomic_property + ) + + return atomic_predictions_tmap_dict + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + context: Literal["restart", "finetune", "export"], + ) -> "PET": + if context == "restart": + logging.info(f"Using latest model from epoch {checkpoint['epoch']}") + model_state_dict = checkpoint["model_state_dict"] + elif context in {"finetune", "export"}: + logging.info(f"Using best model from epoch {checkpoint['best_epoch']}") + model_state_dict = checkpoint["best_model_state_dict"] + else: + raise ValueError("Unknown context tag for checkpoint loading!") + + # Create the model + model_data = checkpoint["model_data"] + model = cls( + hypers=model_data["model_hypers"], + dataset_info=model_data["dataset_info"], + ) + + finetune_config = model_state_dict.pop("finetune_config", {}) + if finetune_config: + # Apply the finetuning strategy + model = apply_finetuning_strategy(model, finetune_config) + state_dict_iter = iter(model_state_dict.values()) + next(state_dict_iter) # skip the species_to_species_index + dtype = next(state_dict_iter).dtype + model.to(dtype).load_state_dict(model_state_dict) + model.additive_models[0].sync_tensor_maps() + model.scaler.sync_tensor_maps() + + # Loading the metadata from the checkpoint + model.metadata = merge_metadata(model.metadata, checkpoint.get("metadata")) + + return model + + def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: + dtype = next(self.parameters()).dtype + if dtype not in self.__supported_dtypes__: + raise ValueError(f"unsupported dtype {dtype} for PET") + + # Make sure the model is all in the same dtype + # For example, after training, the additive models could still be in + # float64 + self.to(dtype) + + # Additionally, the composition model contains some `TensorMap`s that cannot + # be registered correctly with Pytorch. This function moves them: + self.additive_models[0].weights_to(torch.device("cpu"), torch.float64) + + interaction_ranges = [self.num_gnn_layers * self.cutoff] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + + capabilities = ModelCapabilities( + outputs=self.outputs, + atomic_types=self.atomic_types, + interaction_range=interaction_range, + length_unit=self.dataset_info.length_unit, + supported_devices=self.__supported_devices__, + dtype=dtype_to_str(dtype), + ) + + metadata = merge_metadata(self.metadata, metadata) + + return AtomisticModel(self.eval(), metadata, capabilities) + + def _add_output(self, target_name: str, target_info: TargetInfo) -> None: + """ + Register a new output target by creating corresponding heads and last layers. + Sets up node/edge heads and linear layers for all readout layers. + + :param target_name: Name of the target to add. + :param target_info: TargetInfo object containing details about the target. + """ + # one output shape for each tensor block, grouped by target (i.e. 
tensormap) + self.output_shapes[target_name] = {} + for key, block in target_info.layout.items(): + dict_key = target_name + for n, k in zip(key.names, key.values, strict=True): + dict_key += f"_{n}_{int(k)}" + self.output_shapes[target_name][dict_key] = [ + len(comp.values) for comp in block.components + ] + [len(block.properties.values)] + + self.outputs[target_name] = ModelOutput( + quantity=target_info.quantity, + unit=target_info.unit, + per_atom=True, + description=target_info.description, + ) + + self.node_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_node, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_pet, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.node_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: torch.nn.Linear( + self.d_head, + prod(shape), + bias=True, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: torch.nn.Linear( + self.d_head, + prod(shape), + bias=True, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + # Register last-layer parameters, in the same order as they are returned as + # last-layer features in the model + self.last_layer_parameter_names[target_name] = [] + for layer_index in range(self.num_readout_layers): + for key in self.output_shapes[target_name].keys(): + self.last_layer_parameter_names[target_name].append( + f"node_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + self.last_layer_parameter_names[target_name].append( + f"edge_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + + ll_features_name = get_last_layer_features_name(target_name) + self.outputs[ll_features_name] = ModelOutput( + per_atom=True, description=f"last layer features for {target_name}" + ) + self.key_labels[target_name] = target_info.layout.keys + self.component_labels[target_name] = [ + block.components for block in target_info.layout.blocks() + ] + self.property_labels[target_name] = [ + block.properties for block in target_info.layout.blocks() + ] + + def _move_labels_to_device(self, device: torch.device) -> None: + self.single_label = self.single_label.to(device) + self.key_labels = { + output_name: label.to(device) + for output_name, label in self.key_labels.items() + } + self.component_labels = { + output_name: [ + [labels.to(device) for labels in components_block] + for components_block in components_tmap + ] + for output_name, components_tmap in self.component_labels.items() + } + self.property_labels = { + output_name: [labels.to(device) for labels in properties_tmap] + for output_name, properties_tmap in self.property_labels.items() + } + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["model_ckpt_version"] == v: + update = getattr(checkpoints, f"model_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["model_ckpt_version"] = v + 1 + + if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__: + raise 
RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using model " + f"version {checkpoint['model_ckpt_version']}, while the current model " + f"version is {cls.__checkpoint_version__}." + ) + + return checkpoint + + def get_checkpoint(self) -> Dict: + model_state_dict = self.state_dict() + model_state_dict["finetune_config"] = self.finetune_config + checkpoint = { + "architecture_name": "pet_6", + "model_ckpt_version": self.__checkpoint_version__, + "metadata": self.metadata, + "model_data": { + "model_hypers": self.hypers, + "dataset_info": self.dataset_info, + }, + "epoch": None, + "best_epoch": None, + "model_state_dict": model_state_dict, + "best_model_state_dict": self.state_dict(), + } + return checkpoint + + +def process_non_conservative_stress( + tensor: torch.Tensor, + systems: List[System], + system_indices: torch.Tensor, + num_properties: int, +) -> torch.Tensor: + """ + Symmetrizes and normalizes by the volume rank-2 Cartesian tensors that are meant + to predict the non-conservative stress. + + :param tensor: Tensor of shape [n_atoms, 9 * num_properties]. + :param systems: List of `metatomic.torch.System` objects to process. + :param system_indices: Tensor mapping each atom to its system index [n_atoms]. + :param num_properties: Number of properties in the tensor (e.g., 6 for stress). + :return: Symmetrized tensor of shape [n_atoms, 3, 3, num_properties], divided by the + cell volume. + """ + # Reshape to 3x3 matrix per atom + tensor_as_three_by_three = tensor.reshape(-1, 3, 3, num_properties) + + # Normalize by cell volume + volumes = torch.stack([torch.abs(torch.det(system.cell)) for system in systems]) + # Zero volume can happen due to metatomic's convention of zero cell + # vectors for non-periodic directions. The actual volume is +inf + volumes[volumes == 0.0] = torch.inf + volumes_by_atom = volumes[system_indices].unsqueeze(1).unsqueeze(2).unsqueeze(3) + tensor_as_three_by_three = tensor_as_three_by_three / volumes_by_atom + + # Symmetrize + tensor_as_three_by_three = ( + tensor_as_three_by_three + tensor_as_three_by_three.transpose(1, 2) + ) / 2.0 + + return tensor_as_three_by_three + + +def get_last_layer_features_name(target_name: str) -> str: + """ + Get the auxiliary output name for last layer features of a target. + + :param target_name: Name of the target. + :return: Name of the corresponding last layer features output. + """ + base_name = target_name.replace("mtt::", "") + return f"mtt::aux::{base_name}_last_layer_features" + + +def should_compute_last_layer_features( + output_name: str, requested_outputs: Dict[str, ModelOutput] +) -> bool: + """ + Check if last layer features should be computed for an output. + + :param output_name: Name of the output to check. + :param requested_outputs: Dictionary of requested outputs. + :return: True if last layer features should be computed, False otherwise. + """ + if output_name in requested_outputs: + return True + ll_features_name = get_last_layer_features_name( + output_name.replace("mtt::aux::", "") + ) + return ll_features_name in requested_outputs diff --git a/src/metatrain/pet_6/modules/adaptive_cutoff.py b/src/metatrain/pet_6/modules/adaptive_cutoff.py new file mode 100644 index 0000000000..5f6206bdb1 --- /dev/null +++ b/src/metatrain/pet_6/modules/adaptive_cutoff.py @@ -0,0 +1,166 @@ +from typing import Optional + +import torch + +from .utilities import cutoff_func_bump as cutoff_func + + +# minimum value for the probe cutoff. this avoids getting too close +# to the central atom. 
in practice it could be also set to a larger value +DEFAULT_MIN_PROBE_CUTOFF = 0.5 +# recommended smooth cutoff width for effective neighbor number calculation +# smaller values lead to a more "step-like" behavior, but can be +# numerically unstable. in practice this will be called with the +# same cutoff as the main cutoff function +DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH = 1.0 + + +def get_adaptive_cutoffs( + centers: torch.Tensor, + edge_distances: torch.Tensor, + num_neighbors_adaptive: float, + num_nodes: int, + max_cutoff: float, + min_cutoff: float = DEFAULT_MIN_PROBE_CUTOFF, + cutoff_width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, + probe_spacing: Optional[float] = None, + weight_width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the adaptive cutoff values for each center atom. + + :param centers: Indices of the center atoms. + :param edge_distances: Distances between centers and their neighbors. + :param num_neighbors_adaptive: Target number of neighbors per atom. + :param num_nodes: Total number of center atoms. + :param max_cutoff: Maximum cutoff distance to consider. + :param min_cutoff: Minimum cutoff distance to consider. + :param cutoff_width: Width of the smooth cutoff taper region. + :param probe_spacing: Spacing between probe cutoffs. If None, it will be + automatically determined from the cutoff width. + :param weight_width: Width of the cutoff selection weight function. If None, it + will be automatically determined from the empirical neighbor counts. + :return: Adapted cutoff distances for each center atom. + """ + + # heuristic for the grid spacing of probe cutoffs, based on a + # the smoothness of the cutoff function + if probe_spacing is None: + probe_spacing = cutoff_width / 4.0 + probe_cutoffs = torch.arange( + min_cutoff, + max_cutoff, + probe_spacing, + device=edge_distances.device, + dtype=edge_distances.dtype, + ) + with torch.profiler.record_function("PET::get_effective_num_neighbors"): + effective_num_neighbors = get_effective_num_neighbors( + edge_distances, + probe_cutoffs, + centers, + num_nodes, + width=cutoff_width, + ) + + with torch.profiler.record_function("PET::get_cutoff_weights"): + cutoffs_weights = get_gaussian_cutoff_weights( + effective_num_neighbors, + num_neighbors_adaptive, + width=weight_width, + ) + with torch.profiler.record_function("PET::calculate_adapted_cutoffs"): + adapted_atomic_cutoffs = probe_cutoffs @ cutoffs_weights.T + return adapted_atomic_cutoffs + + +def get_effective_num_neighbors( + edge_distances: torch.Tensor, + probe_cutoffs: torch.Tensor, + centers: torch.Tensor, + num_nodes: int, + width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, +) -> torch.Tensor: + """ + Computes the effective number of neighbors for each probe cutoff. + + :param edge_distances: Distances between centers and their neighbors. + :param probe_cutoffs: Probe cutoff distances. + :param centers: Indices of the center atoms. + :param num_nodes: Total number of center atoms. + :param width: Width of the cutoff function. + :return: Effective number of neighbors for each center atom and probe cutoff. 
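+
+    In effect, this is a smooth neighbor count: for center ``i`` and probe cutoff
+    ``r_c``, ``N_eff(i, r_c) = sum_j f_cut(r_ij, r_c, width)``, where ``f_cut`` is
+    the bump cutoff function imported at the top of this module.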
+ """ + + weights = cutoff_func( + edge_distances.unsqueeze(0), probe_cutoffs.unsqueeze(1), width + ) + + probe_num_neighbors = torch.zeros( + (len(probe_cutoffs), num_nodes), + dtype=edge_distances.dtype, + device=edge_distances.device, + ) + # accumulate the weights for all probe cutoffs and center atoms at once + probe_num_neighbors.index_add_(1, centers, weights) + probe_num_neighbors = probe_num_neighbors.T + + return probe_num_neighbors + + +def get_gaussian_cutoff_weights( + effective_num_neighbors: torch.Tensor, + num_neighbors_adaptive: float, + width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the weights for each probe cutoff based on + the effective number of neighbors using Gaussian weights + centered at the expected number of neighbors. + + :param effective_num_neighbors: Effective number of neighbors for each center atom + and probe cutoff. + :param num_neighbors_adaptive: Target maximum number of neighbors per atom. + :param width: Width of the Gaussian cutoff selection function. + :return: Weights for each probe cutoff. + """ + diff = effective_num_neighbors - num_neighbors_adaptive + + # adds a "baseline" corresponding to uniformly-distributed atoms + # this has multiple "good" effects: it pushes the cutoff "out" when + # there are few neighbors, and "in" when there are many, and it + # stabilizes the weights with respect to variations in the neighbor + # distribution when there are empty ranges leading to "flat" + # neighbor count distribution + x = torch.linspace( + 0, + 1, + effective_num_neighbors.shape[1], + device=effective_num_neighbors.device, + dtype=effective_num_neighbors.dtype, + ) + baseline = num_neighbors_adaptive * x**3 + + diff = diff + baseline.unsqueeze(0) + if width is None: + # adaptive width from neighbor-count slope along probe axis (last dim) + eps = 1e-12 + if diff.shape[-1] == 1: + # Can't compute gradient from single point; use scaled diff as proxy + width_t = diff.abs() * 0.5 + eps + else: + # Compute numerical gradient: centered differences for interior, + # one-sided differences at boundaries + (width_t,) = torch.gradient(diff, dim=-1) + width_t = width_t.abs().clamp_min(eps) + else: + width_t = torch.ones_like(diff) * width + + logw = -0.5 * (diff / width_t) ** 2 + weights = torch.exp(logw - logw.max()) + + # row-wise normalization of the weights + weights_sum = weights.sum(dim=1, keepdim=True) + weights = weights / weights_sum + + return weights diff --git a/src/metatrain/pet_6/modules/finetuning.py b/src/metatrain/pet_6/modules/finetuning.py new file mode 100644 index 0000000000..27f2086f19 --- /dev/null +++ b/src/metatrain/pet_6/modules/finetuning.py @@ -0,0 +1,264 @@ +# mypy: disable-error-code=misc +# We ignore misc errors in this file because TypedDict +# with default values is not allowed by mypy. 
+from typing import Any, Optional, Tuple + +import torch +import torch.nn as nn +from typing_extensions import Literal, NotRequired, TypedDict + + +class LoRaFinetuneConfig(TypedDict): + """Configuration for LoRA finetuning strategy.""" + + rank: int + """Rank of the LoRA matrices.""" + alpha: float + """Scaling factor for the LoRA matrices.""" + target_modules: NotRequired[list[str]] + + +class HeadsFinetuneConfig(TypedDict): + """Configuration for heads finetuning strategy.""" + + head_modules: list[str] + """List of module name prefixes for the prediction heads to finetune.""" + last_layer_modules: list[str] + """List of module name prefixes for the last layers to finetune.""" + + +class NoFinetuneHypers(TypedDict): + """Hypers that indicate that no finetuning is to be applied.""" + + read_from: None = None + """No finetuning is indicated by setting this argument to None. + + The rest of finetuning hyperparameters are then ignored. + """ + method: NotRequired[Any] + config: NotRequired[Any] + inherit_heads: NotRequired[Any] + + +class FullFinetuneHypers(TypedDict): + """Hyperparameters to use full finetuning of PET models. + + This means all model parameters are trainable. + """ + + method: Literal["full"] = "full" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: NotRequired[Any] + """No configuration needed for full finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class LoRaFinetuneHypers(TypedDict): + """Hyperparameters for LoRA finetuning of PET models. + + Injects LoRA layers and finetunes only them. + """ + + method: Literal["lora"] = "lora" + """Finetuning method to use""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: LoRaFinetuneConfig + """Configuration for LoRA finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class HeadsFinetuneHypers(TypedDict): + """Hyperparameters for heads finetuning of PET models. + + Freezes all model parameters except for the prediction heads + and last layers. + """ + + method: Literal["heads"] = "heads" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: HeadsFinetuneConfig + """Configuration for heads finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +FinetuneHypers = FullFinetuneHypers | LoRaFinetuneHypers | HeadsFinetuneHypers + + +def apply_finetuning_strategy(model: nn.Module, strategy: FinetuneHypers) -> nn.Module: + """ + Apply the specified finetuning strategy to the model. + This function modifies the model in place based on the provided strategy. + + :param model: The model to be finetuned. + :param strategy: A dictionary specifying the finetuning strategy. + The strategy method can be one of the following: + - lora: Inject LoRA layers into the model, or reapply training if already + present. 
+ - heads: Freeze all parameters except for the heads and last layers. + - full: All parameters are trainable. + Additionally, the strategy can include an "inherit_heads" key, + which is a dictionary mapping the new trainable targets to the existing + targets in the model. This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization. + :return: The modified model with the finetuning strategy applied. + """ + + for param in model.parameters(): + param.requires_grad = True + + if strategy["method"] == "full": + # Full finetuning, all parameters are trainable + pass + + elif strategy["method"] == "lora": + lora_config = strategy["config"] + lora_already_applied = any(isinstance(m, LoRALinear) for m in model.modules()) + if not lora_already_applied: + model_device = next(model.parameters()).device + model_dtype = next(model.parameters()).dtype + model = inject_lora_layers( + model, + target_modules=tuple( + lora_config.get("target_modules", ["input_linear", "output_linear"]) + ), + rank=lora_config.get("rank", 4), + alpha=lora_config.get("alpha", 8), + device=model_device, + dtype=model_dtype, + ) + + # Freeze all except LoRA + for name, param in model.named_parameters(): + if "lora_" not in name: + param.requires_grad = False + + elif strategy["method"] == "heads": + heads_config = strategy.get( + "config", + { + "head_modules": ["node_heads", "edge_heads"], + "last_layer_modules": ["node_last_layers", "edge_last_layers"], + }, + ) + + head_keywords = heads_config.get("head_modules", []) + last_layer_keywords = heads_config.get("last_layer_modules", []) + + for name, param in model.named_parameters(): + if any(name.startswith(kw) for kw in head_keywords + last_layer_keywords): + param.requires_grad = True + else: + param.requires_grad = False + + else: + raise ValueError( + f"Unknown finetuning strategy: {strategy['method']}. Available methods " + "are: 'full', 'lora', 'heads'." + ) + + model.finetune_config = strategy + + inherit_heads_config = strategy["inherit_heads"] + if inherit_heads_config: + for dest_target_name, source_target_name in inherit_heads_config.items(): + model_parameters = dict(model.named_parameters()) + if not any(f".{source_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the source target name '{source_target_name}' was not found in " + "the model. Please specify the correct source target name." + ) + if not any(f".{dest_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the destination target name '{dest_target_name}' was not found " + "in the model. Please specify the correct destination target name." + ) + for name, param in model_parameters.items(): + if f".{source_target_name}." in name: + corresponding_dest_name = name.replace( + source_target_name, dest_target_name + ) + if corresponding_dest_name in model_parameters: + model_parameters[corresponding_dest_name].data.copy_(param.data) + else: + raise ValueError( + f"Destination head '{dest_target_name}' not found in model." + ) + return model + + +def inject_lora_layers( + model: nn.Module, + target_modules: Tuple[str, ...] = ("input_linear", "output_linear"), + rank: int = 4, + alpha: float = 1.0, + device: Optional[str] = None, + dtype: Optional[torch.dtype] = None, +) -> nn.Module: + """ + Inject LoRA layers into the model. 
+ This function replaces the specified linear layers in the model with + LoRALinear layers. + + :param model: The model to modify. + :param target_modules: A tuple of strings specifying the names of the attributes of + the modules to be replaced with LoRA layers. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + :param device: The device to which the LoRA layers should be moved. If None, the + LoRA layers will be on the same device as the original model. + :param dtype: The data type to which the LoRA layers should be converted. If + None, the LoRA layers will have the same dtype as the original model. + :return: The modified model with LoRA layers injected. + """ + for _, module in model.named_modules(): + for attr in target_modules: + if hasattr(module, attr): + linear = getattr(module, attr) + if isinstance(linear, nn.Linear): + lora_linear = LoRALinear(linear, rank=rank, alpha=alpha) + lora_linear = lora_linear.to(dtype=dtype, device=device) + setattr(module, attr, lora_linear) + return model + + +class LoRALinear(nn.Module): + """ + LoRA Linear layer. + This is a wrapper around nn.Linear that adds LoRA functionality. + LoRA is a technique for low-rank adaptation of large language models. + It allows for efficient fine-tuning of large models by injecting low-rank + matrices into the model's weights. + + :param linear_layer: The original linear layer to be wrapped. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + """ + + def __init__(self, linear_layer: nn.Module, rank: int = 4, alpha: float = 1.0): + super().__init__() + self.linear = linear_layer + self.lora_A = nn.Linear(linear_layer.in_features, rank, bias=False) + self.lora_B = nn.Linear(rank, linear_layer.out_features, bias=False) + self.scaling = alpha / rank + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + self.scaling * self.lora_B(self.lora_A(x)) diff --git a/src/metatrain/pet_6/modules/nef.py b/src/metatrain/pet_6/modules/nef.py new file mode 100644 index 0000000000..20ed6dc8fc --- /dev/null +++ b/src/metatrain/pet_6/modules/nef.py @@ -0,0 +1,248 @@ +""" +Module with functions to manipulate NEF (Node Edge Feature) arrays. + +The NEF representation is what the internals of PET use. +In the NEF representation, the first dimension is the center node +(i.e. the "i" node in an "i -> j" edge), and the second dimension +is the edges for that node. Not all center nodes have the same number +of edges, so padding is used to ensure that all nodes have the same +number of edges. + +Most of the functions have the purpose of converting between +edge arrays with shape (n_edges, ...) and NEF arrays with shape +(n_nodes, n_edges_per_node, ...). +""" + +from typing import List, Optional, Tuple + +import torch + + +def get_nef_indices( + centers: torch.Tensor, n_nodes: int, n_edges_per_node: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes tensors of indices useful to convert between edge + and NEF layouts; the usage and function of `nef_indices` and + `nef_to_edges_neighbor` is clear in the ``edge_array_to_nef`` + and ``nef_array_to_edges`` functions below. + + :param centers: A 1D tensor of shape (n_edges,) containing the + indices of the center nodes for each edge, with the center nodes + being the "i" node in an "i -> j" edge. + :param n_nodes: The number of nodes in the graph. + :param n_edges_per_node: The maximum number of edges per node. 
+ + :return: A tuple with three tensors (nef_indices, nef_to_edges_neighbor, nef_mask). + In particular: + nef_array = edge_array[nef_indices] + edge_array = nef_array[centers, nef_to_edges_neighbor] + The third output, nef_mask, is a mask that can be used to + filter out the padding values in the NEF array, as different + nodes will have, in general, different number of edges. + """ + + bincount = torch.bincount(centers, minlength=n_nodes) + + arange = torch.arange(n_edges_per_node, device=centers.device) + arange_expanded = arange.view(1, -1).expand(n_nodes, -1) + nef_mask = arange_expanded < bincount.view(-1, 1) + + argsort = torch.argsort(centers, stable=True) + + nef_indices = torch.zeros( + (n_nodes, n_edges_per_node), dtype=torch.long, device=centers.device + ) + nef_indices[nef_mask] = argsort + + nef_to_edges_neighbor = torch.empty_like(centers, dtype=torch.long) + nef_to_edges_neighbor[argsort] = arange_expanded[nef_mask] + + return nef_indices, nef_to_edges_neighbor, nef_mask + + +def get_corresponding_edges(array: torch.Tensor) -> torch.Tensor: + """ + Computes the corresponding edge (i.e., the edge that goes in the + opposite direction) for each edge in the array; this is useful + in the message-passing operation. + + :param array: A 2D tensor of shape (n_edges, 5). For each i -> j + edge, the first column contains the index of the center node i, + the second column contains the index of the neighbor node j, + and the last three columns contain the cell shifts along x, y, and z + directions, respectively. + + :return: A 1D tensor of shape (n_edges,) containing, for each edge, + the index of the corresponding edge (i.e., the edge that goes + in the opposite direction). If the input array is empty, an + empty tensor is returned. + """ + + if array.numel() == 0: + return torch.empty((0,), dtype=array.dtype, device=array.device) + + array = array.to(torch.int64) # avoid overflow + + centers = array[:, 0] + neighbors = array[:, 1] + cell_shifts_x = array[:, 2] + cell_shifts_y = array[:, 3] + cell_shifts_z = array[:, 4] + + # will be useful later + negative_cell_shifts_x = -cell_shifts_x + negative_cell_shifts_y = -cell_shifts_y + negative_cell_shifts_z = -cell_shifts_z + + # create a unique identifier for each edge + # first, we shift the cell_shifts so that the minimum value is 0 + min_cell_shift_x = cell_shifts_x.min() + cell_shifts_x = cell_shifts_x - min_cell_shift_x + negative_cell_shifts_x = negative_cell_shifts_x - min_cell_shift_x + + min_cell_shift_y = cell_shifts_y.min() + cell_shifts_y = cell_shifts_y - min_cell_shift_y + negative_cell_shifts_y = negative_cell_shifts_y - min_cell_shift_y + + min_cell_shift_z = cell_shifts_z.min() + cell_shifts_z = cell_shifts_z - min_cell_shift_z + negative_cell_shifts_z = negative_cell_shifts_z - min_cell_shift_z + + max_centers_neigbors = centers.max() + 1 # same as neighbors.max() + 1 + max_shift_x = cell_shifts_x.max() + 1 + max_shift_y = cell_shifts_y.max() + 1 + max_shift_z = cell_shifts_z.max() + 1 + + size_1 = max_shift_z + size_2 = max_shift_y * size_1 + size_3 = max_shift_x * size_2 + size_4 = max_centers_neigbors * size_3 + + unique_id = ( + centers * size_4 + + neighbors * size_3 + + cell_shifts_x * size_2 + + cell_shifts_y * size_1 + + cell_shifts_z + ) + + # the inverse is the same, but centers and neighbors are swapped + # and we use the negative values of the cell_shifts + unique_id_inverse = ( + neighbors * size_4 + + centers * size_3 + + negative_cell_shifts_x * size_2 + + negative_cell_shifts_y * size_1 + + 
negative_cell_shifts_z + ) + + unique_id_argsort = unique_id.argsort() + unique_id_inverse_argsort = unique_id_inverse.argsort() + + corresponding_edges = torch.empty_like(centers) + corresponding_edges[unique_id_argsort] = unique_id_inverse_argsort + + return corresponding_edges.to(array.dtype) + + +def edge_array_to_nef( + edge_array: torch.Tensor, + nef_indices: torch.Tensor, + mask: Optional[torch.Tensor] = None, + fill_value: float = 0.0, +) -> torch.Tensor: + """Converts an edge array to a NEF array. + + :param edge_array: A tensor where the first dimension is the index of + the edge, i.e. with shape (n_edges, ...). + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param mask: An optional boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. If provided, + the output NEF array will have the values in the positions + where the mask is False set to ``fill_value``. + :param fill_value: The value to use to fill the positions in the + NEF array where the mask is False. Only used if ``mask`` is + provided. + + :return: A tensor with the same information as ``edge_array``, + but in NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + If ``mask`` is provided, the values in the positions where + the mask is False are set to ``fill_value``. + """ + if mask is None: + return edge_array[nef_indices] + else: + return torch.where( + mask.reshape(mask.shape + (1,) * (len(edge_array.shape) - 1)), + edge_array[nef_indices], + fill_value, + ) + + +def nef_array_to_edges( + nef_array: torch.Tensor, centers: torch.Tensor, nef_to_edges_neighbor: torch.Tensor +) -> torch.Tensor: + """Converts a NEF array to an edge array. + + :param nef_array: A tensor where the first two dimensions are the + indices of the NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + :param centers: The indices of the center nodes for each edge. + :param nef_to_edges_neighbor: The indices of the edges for each + neighbor in the NEF layout, as returned by the ``get_nef_indices`` function. + + :return: A tensor with the same information as ``nef_array``, + but in edge layout, i.e. with shape (n_edges, ...). + """ + return nef_array[centers, nef_to_edges_neighbor] + + +def compute_reversed_neighbor_list( + nef_indices: torch.Tensor, + corresponding_edges: torch.Tensor, + nef_mask: torch.Tensor, +) -> torch.Tensor: + """ + Creates a reversed neighborlist, where for each + center atom `i` and its neighbor `j` in the original + neighborlist, the position of atom `i` in the list + of neighbors of atom `j` is returned. + + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param corresponding_edges: The indices of the corresponding edges, + as returned by the ``get_corresponding_edges`` function. + :param nef_mask: A boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. + :return: A tensor of the same shape as ``nef_indices``, + where each entry contains the position of the center + atom in the neighborlist of the corresponding neighbor atom. 
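+
+ A minimal usage sketch (the tensors below are illustrative and assume the
+ other helpers in this module are available; they are not taken from a real
+ system)::
+
+     import torch
+     # two atoms, each the center of one edge pointing to the other, no cell shifts
+     centers = torch.tensor([0, 1])
+     neighbors = torch.tensor([1, 0])
+     cell_shifts = torch.zeros((2, 3), dtype=torch.long)
+     nef_indices, _, nef_mask = get_nef_indices(centers, n_nodes=2, n_edges_per_node=1)
+     corresponding_edges = get_corresponding_edges(
+         torch.cat([centers.unsqueeze(-1), neighbors.unsqueeze(-1), cell_shifts], dim=-1)
+     )
+     reversed_nl = compute_reversed_neighbor_list(nef_indices, corresponding_edges, nef_mask)
+     # reversed_nl[i, k] is the position of atom i in the neighbor list of its k-th neighbor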
+ """ + num_atoms, max_num_neighbors = nef_indices.shape + + flat_edge_indices = nef_indices.reshape(-1) + flat_positions = torch.arange(max_num_neighbors, device=nef_indices.device).repeat( + num_atoms + ) + flat_mask = nef_mask.reshape(-1) + + if flat_edge_indices.numel() == 0: + max_edge_index = 0 + else: + max_edge_index = int(flat_edge_indices.max().item()) + 1 + size: List[int] = [max_edge_index] + + edge_index_to_position = torch.full( + size, + 0, + dtype=torch.long, + device=nef_indices.device, + ) + edge_index_to_position[flat_edge_indices[flat_mask]] = flat_positions[flat_mask] + + reverse_edge_idx = corresponding_edges[nef_indices] + reversed_neighbor_list = edge_index_to_position[reverse_edge_idx] + reversed_neighbor_list = reversed_neighbor_list.masked_fill(~nef_mask, 0) + + return reversed_neighbor_list diff --git a/src/metatrain/pet_6/modules/structures.py b/src/metatrain/pet_6/modules/structures.py new file mode 100644 index 0000000000..f7e67f9eaf --- /dev/null +++ b/src/metatrain/pet_6/modules/structures.py @@ -0,0 +1,377 @@ +from typing import List, Optional, Tuple + +import torch +from metatensor.torch import Labels +from metatomic.torch import NeighborListOptions, System + +from .adaptive_cutoff import get_adaptive_cutoffs +from .nef import ( + compute_reversed_neighbor_list, + edge_array_to_nef, + get_corresponding_edges, + get_nef_indices, +) +from .utilities import cutoff_func_bump, cutoff_func_cosine + + +def concatenate_structures( + systems: List[System], + neighbor_list_options: NeighborListOptions, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, +]: + """ + Concatenate a list of systems into a single batch. + + :param systems: List of systems to concatenate. + :param neighbor_list_options: Options for the neighbor list. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the concatenated positions, centers, neighbors, + species, cells, cell shifts, system indices, and sample labels. 
+ """ + + positions: List[torch.Tensor] = [] + centers: List[torch.Tensor] = [] + neighbors: List[torch.Tensor] = [] + species: List[torch.Tensor] = [] + cell_shifts: List[torch.Tensor] = [] + cells: List[torch.Tensor] = [] + system_indices: List[torch.Tensor] = [] + atom_indices: List[torch.Tensor] = [] + node_counter = 0 + + for i, system in enumerate(systems): + assert len(system.known_neighbor_lists()) >= 1, "no neighbor list found" + neighbor_list = system.get_neighbor_list(neighbor_list_options) + nl_values = neighbor_list.samples.values + + centers_values = nl_values[:, 0] + neighbors_values = nl_values[:, 1] + cell_shifts_values = nl_values[:, 2:] + + if selected_atoms is not None: + system_selected_atoms = selected_atoms.values[:, 1][ + selected_atoms.values[:, 0] == i + ] + unique_centers = torch.unique(centers_values) + system_selected_atoms = torch.unique( + torch.cat([system_selected_atoms, unique_centers]) + ) + # calculate the mapping from the ghost atoms to the real atoms + if torch.numel(unique_centers) == 0: + max_center_index = -1 + else: + max_center_index = int(unique_centers.max()) + ghost_to_real_index = torch.full( + [ + max_center_index + 1, + ], + -1, + device=centers_values.device, + dtype=centers_values.dtype, + ) + for j, unique_center_index in enumerate(unique_centers): + ghost_to_real_index[unique_center_index] = j + + centers_values = ghost_to_real_index[centers_values] + neighbors_values = ghost_to_real_index[neighbors_values] + else: + system_selected_atoms = torch.arange( + len(system), device=system.positions.device + ) + + positions.append(system.positions[system_selected_atoms]) + species.append(system.types[system_selected_atoms]) + + centers.append(centers_values + node_counter) + neighbors.append(neighbors_values + node_counter) + cell_shifts.append(cell_shifts_values) + + cells.append(system.cell) + + node_counter += len(system_selected_atoms) + system_indices.append( + torch.full((len(system_selected_atoms),), i, device=system.positions.device) + ) + atom_indices.append( + torch.arange(len(system_selected_atoms), device=system.positions.device) + ) + + positions = torch.cat(positions) + centers = torch.cat(centers) + neighbors = torch.cat(neighbors) + species = torch.cat(species) + cells = torch.stack(cells) + cell_shifts = torch.cat(cell_shifts) + system_indices = torch.cat(system_indices) + atom_indices = torch.cat(atom_indices) + + sample_values = torch.stack( + [system_indices, atom_indices], + dim=1, + ) + sample_labels = Labels( + names=["system", "atom"], + values=sample_values, + assume_unique=True, + ) + + return ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) + + +def systems_to_batch( + systems: List[System], + options: NeighborListOptions, + all_species_list: List[int], + species_to_species_index: torch.Tensor, + cutoff_function: str, + cutoff_width: float, + num_neighbors_adaptive: Optional[float] = None, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, + torch.Tensor, + torch.Tensor, +]: + """ + Converts a list of systems to a batch required for the PET model. + + :param systems: List of systems to convert to a batch. + :param options: Options for the neighbor list. + :param all_species_list: List of all atomic species in the dataset. + :param species_to_species_index: Mapping from atomic species to species indices. 
+ :param cutoff_function: Type of the smoothing function at the cutoff. + :param cutoff_width: Width of the cutoff function for a cutoff mask. + :param num_neighbors_adaptive: Optional maximum number of neighbors per atom. + If provided, the adaptive cutoff scheme will be used for each atom to + approximately select this number of neighbors. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the batch tensors. + The batch consists of the following tensors: + - `element_indices_nodes`: The atomic species of the central atoms + - `element_indices_neighbors`: The atomic species of the neighboring atoms + - `edge_vectors`: The cartesian edge vectors between the central atoms and their + neighbors + - `edge_distances`: The distances between the central atoms and their neighbors + - `padding_mask`: A padding mask indicating which neighbors are real, and which + are padded + - `reverse_neighbor_index`: The reversed neighbor list for each central atom + - `cutoff_factors`: The cutoff function values for each edge + - `system_indices`: The system index for each atom in the batch + - `sample_labels`: Labels indicating the system and atom indices for each atom + + """ + ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) = concatenate_structures(systems, options, selected_atoms) + + # somehow the backward of this operation is very slow at evaluation, + # where there is only one cell, therefore we simplify the calculation + # for that case + if len(cells) == 1: + cell_contributions = cell_shifts.to(cells.dtype) @ cells[0] + else: + cell_contributions = torch.einsum( + "ab, abc -> ac", + cell_shifts.to(cells.dtype), + cells[system_indices[centers]], + ) + edge_vectors = positions[neighbors] - positions[centers] + cell_contributions + edge_distances = torch.norm(edge_vectors, dim=-1) + 1e-15 + + if selected_atoms is not None: + if torch.numel(centers) == 0: + num_nodes = 0 + else: + num_nodes = int(centers.max()) + 1 + else: + num_nodes = len(positions) + + if num_neighbors_adaptive is not None: + with torch.profiler.record_function("PET::get_adaptive_cutoffs"): + # Adaptive cutoff scheme to approximately select `num_neighbors_adaptive` + # neighbors for each atom + atomic_cutoffs = get_adaptive_cutoffs( + centers, + edge_distances, + num_neighbors_adaptive, + num_nodes, + options.cutoff, + cutoff_width=cutoff_width, + ) + # Symmetrize the cutoffs between pairs of atoms (PET needs this symmetry + # due to its corresponding edge indexing ij -> ji) + pair_cutoffs = (atomic_cutoffs[centers] + atomic_cutoffs[neighbors]) / 2.0 + with torch.profiler.record_function("PET::adaptive_cutoff_masking"): + # Apply cutoff mask + cutoff_mask = edge_distances <= pair_cutoffs + + pair_cutoffs = pair_cutoffs[cutoff_mask] + centers = centers[cutoff_mask] + neighbors = neighbors[cutoff_mask] + edge_vectors = edge_vectors[cutoff_mask] + cell_shifts = cell_shifts[cutoff_mask] + edge_distances = edge_distances[cutoff_mask] + else: + pair_cutoffs = options.cutoff * torch.ones( + len(centers), device=positions.device, dtype=positions.dtype + ) + + num_neighbors = torch.bincount(centers, minlength=num_nodes) + max_edges_per_node = int(torch.max(num_neighbors)) + + # uncomment these to print out stats on the adaptive cutoff behavior + # print("adaptive_cutoffs", *pair_cutoffs.tolist()) + # print("num_neighbors", *num_neighbors.tolist()) + + if cutoff_function.lower() == "bump": + # use bump switching 
function for adaptive cutoff + cutoff_factors = cutoff_func_bump(edge_distances, pair_cutoffs, cutoff_width) + elif cutoff_function.lower() == "cosine": + # backward-compatible cosine swithcing for fixed cutoff + cutoff_factors = cutoff_func_cosine(edge_distances, pair_cutoffs, cutoff_width) + else: + raise ValueError( + f"Unknown cutoff function type: {cutoff_function}. " + f"Supported types are 'Cosine' and 'Bump'." + ) + + # Convert to NEF (Node-Edge-Feature) format: + nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices( + centers, num_nodes, max_edges_per_node + ) + + # Element indices + element_indices_nodes = species_to_species_index[species] + element_indices_neighbors = element_indices_nodes[neighbors] + + # Send everything to NEF: + edge_vectors = edge_array_to_nef(edge_vectors, nef_indices) + edge_distances = torch.sqrt(torch.sum(edge_vectors**2, dim=2) + 1e-15) + element_indices_neighbors = edge_array_to_nef( + element_indices_neighbors, nef_indices + ) + cutoff_factors = edge_array_to_nef(cutoff_factors, nef_indices, nef_mask, 0.0) + + corresponding_edges = get_corresponding_edges( + torch.concatenate( + [centers.unsqueeze(-1), neighbors.unsqueeze(-1), cell_shifts], + dim=-1, + ) + ) + + # These are the two arrays we need for message passing with edge reversals, + # if indexing happens in a two-dimensional way: + # edges_ji = edges_ij[reversed_neighbor_list, neighbors_index] + reversed_neighbor_list = compute_reversed_neighbor_list( + nef_indices, corresponding_edges, nef_mask + ) + neighbors_index = edge_array_to_nef(neighbors, nef_indices).to(torch.int64) + + # Here, we compute the array that allows indexing into a flattened + # version of the edge array (where the first two dimensions are merged): + reverse_neighbor_index = ( + neighbors_index * neighbors_index.shape[1] + reversed_neighbor_list + ) + # At this point, we have `reverse_neighbor_index[~nef_mask] = 0`, which however + # creates too many of the same index which slows down backward enormously. + # (See see https://github.com/pytorch/pytorch/issues/41162) + # We therefore replace the padded indices with a sequence of unique indices. + reverse_neighbor_index[~nef_mask] = torch.arange( + int(torch.sum(~nef_mask)), device=reverse_neighbor_index.device + ) + + return ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + nef_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + sample_labels, + centers, + nef_to_edges_neighbor, + ) + + +def get_pair_sample_labels( + systems: List[System], + sample_labels: Labels, + nl_options: NeighborListOptions, + device: torch.device, +) -> Labels: + """ + Builds the pair samples labels for the input ``systems``, based on the pre-computed + neighbor list. These are 'off-site', i.e. not including self-interactions. + + :param systems: List of systems to build the pair sample labels for. + :param sample_labels: The sample labels for per-atom quantities. + :param nl_options: The neighbor list options to use for building the offsite labels. + :param device: The device to put the labels on. + :return: A dictionary with the pair sample labels for the onsite and offsite blocks. 
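+
+ For illustration, each row of the returned pair sample labels corresponds to
+ one ``i -> j`` pair of a system, with dimensions::
+
+     ("system", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c")
+
+ so that, for instance, a row like ``(0, 3, 7, 0, 0, 1)`` would denote the pair
+ between atoms 3 and 7 of system 0, shifted by one cell along the third lattice
+ vector (the concrete values here are hypothetical).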
+ """ + sample_names = [ + "system", + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ] + + pair_sample_values = [] + for system_idx, system in enumerate(systems): + neighbor_list = system.get_neighbor_list(nl_options) + nl_values = neighbor_list.samples.values + + pair_sample_values.append( + torch.hstack( + [ + torch.full( + (nl_values.shape[0], 1), + system_idx, + dtype=torch.int32, + device=device, + ), + nl_values, + ], + ) + ) + pair_sample_values = torch.vstack(pair_sample_values) + pair_sample_labels = Labels(sample_names, pair_sample_values).to(device=device) + + return pair_sample_labels diff --git a/src/metatrain/pet_6/modules/transformer.py b/src/metatrain/pet_6/modules/transformer.py new file mode 100644 index 0000000000..b1e5349afc --- /dev/null +++ b/src/metatrain/pet_6/modules/transformer.py @@ -0,0 +1,562 @@ +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from .utilities import DummyModule +from sphericart.torch import SphericalHarmonics + + +AVAILABLE_NORMALIZATIONS = ["LayerNorm", "RMSNorm"] +AVAILABLE_TRANSFORMER_TYPES = ["PostLN", "PreLN"] +AVAILABLE_ACTIVATIONS = ["SiLU", "SwiGLU"] + + +class FeedForward(nn.Module): + def __init__(self, d_model: int, dim_feedforward: int, activation: str) -> None: + super().__init__() + + # Check if activation is "swiglu" string + if activation.lower() == "swiglu": + # SwiGLU mode: single projection produces both "value" and "gate" + self.w_in = nn.Linear(d_model, 2 * dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = torch.nn.Identity() + self.is_swiglu = True + else: + # Standard mode: regular activation function + self.w_in = nn.Linear(d_model, dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = getattr(F, activation.lower()) + self.is_swiglu = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_swiglu: + # SwiGLU activation: split into value and gate + v, g = self.w_in(x).chunk(2, dim=-1) + x = v * torch.sigmoid(g) + x = self.w_out(x) + else: + # Standard activation + x = self.w_in(x) + x = self.activation(x) + x = self.w_out(x) + return x + + +class AttentionBlock(nn.Module): + """ + Multi-head attention block. + + :param total_dim: The total dimension of the input and output tensors. + :param num_heads: The number of attention heads. + :param temperature: An additional scaling factor for attention scores. + This is combined with the standard scaling by the square root of + the head dimension. + :param epsilon: A small value to avoid division by zero. + """ + + def __init__( + self, + total_dim: int, + num_heads: int, + temperature: float, + epsilon: float = 1e-15, + ) -> None: + super(AttentionBlock, self).__init__() + + self.input_linear = nn.Linear(total_dim, 3 * total_dim) + self.output_linear = nn.Linear(total_dim, total_dim) + + self.num_heads = num_heads + self.epsilon = epsilon + self.temperature = temperature + if total_dim % num_heads != 0: + raise ValueError("total dimension is not divisible by the number of heads") + self.head_dim = total_dim // num_heads + + def forward( + self, x: torch.Tensor, cutoff_factors: torch.Tensor, use_manual_attention: bool + ) -> torch.Tensor: + """ + Forward pass for the attention block. 
+ + :param x: The input tensor, of shape (batch_size, seq_length, total_dim) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: The output tensor, of shape (batch_size, seq_length, total_dim) + """ + initial_shape = x.shape + x = self.input_linear(x) + x = x.reshape( + initial_shape[0], initial_shape[1], 3, self.num_heads, self.head_dim + ) + x = x.permute(2, 0, 3, 1, 4) + + queries, keys, values = x[0], x[1], x[2] + attn_weights = torch.clamp(cutoff_factors[:, None, :, :], self.epsilon) + attn_weights = torch.log(attn_weights) + if use_manual_attention: + x = manual_attention(queries, keys, values, attn_weights, self.temperature) + else: + x = torch.nn.functional.scaled_dot_product_attention( + queries, + keys, + values, + attn_mask=attn_weights, + scale=1.0 / (self.head_dim**0.5 * self.temperature), + ) + x = x.transpose(1, 2).reshape(initial_shape) + x = self.output_linear(x) + return x + + +class TransformerLayer(torch.nn.Module): + """ + Single layer of a Transformer. + + :param d_model: The dimension of the model. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param temperature: An additional scaling factor for attention scores. 
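+
+ A minimal instantiation sketch (hyperparameter values and shapes are
+ illustrative only; the node embeddings carry a single "central" token per
+ node, as in ``CartesianTransformer`` below)::
+
+     import torch
+     layer = TransformerLayer(
+         d_model=128, n_heads=4, dim_node_features=128, dim_feedforward=512,
+         norm="LayerNorm", activation="SiLU", transformer_type="PostLN",
+     )
+     node_embeddings = torch.randn(10, 1, 128)   # (n_nodes, 1, d_model)
+     edge_embeddings = torch.randn(10, 32, 128)  # (n_nodes, n_edges_per_node, d_model)
+     cutoff_factors = torch.ones(10, 33, 33)     # includes the central token
+     node_embeddings, edge_embeddings = layer(
+         node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention=True
+     )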
+ """ + + def __init__( + self, + d_model: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + temperature: float = 1.0, + ) -> None: + super(TransformerLayer, self).__init__() + self.attention = AttentionBlock(d_model, n_heads, temperature) + self.transformer_type = transformer_type + self.d_model = d_model + norm_class = getattr(nn, norm) + self.norm_attention = norm_class(d_model) + self.norm_mlp = norm_class(d_model) + self.mlp = FeedForward(d_model, dim_feedforward, activation) + self.expanded_node_features = False + if dim_node_features != d_model: + self.expanded_node_features = True + self.center_contraction = nn.Linear(dim_node_features, d_model) + self.center_expansion = nn.Linear(d_model, dim_node_features) + self.norm_center_features = norm_class(dim_node_features) + self.center_mlp = FeedForward( + dim_node_features, 2 * dim_node_features, activation + ) + else: + self.center_contraction = torch.nn.Identity() + self.center_expansion = torch.nn.Identity() + self.norm_center_features = torch.nn.Identity() + self.center_mlp = torch.nn.Identity() + + def _forward_pre_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + new_tokens = self.attention( + self.norm_attention(tokens), cutoff_factors, use_manual_attention + ) + output_node_embeddings, output_edge_embeddings = torch.split( + new_tokens, [1, new_tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + + output_edge_embeddings = edge_embeddings + output_edge_embeddings + output_edge_embeddings = output_edge_embeddings + self.mlp( + self.norm_mlp(output_edge_embeddings) + ) + + return output_node_embeddings, output_edge_embeddings + + def _forward_post_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + tokens = self.norm_attention( + tokens + self.attention(tokens, cutoff_factors, use_manual_attention) + ) + tokens = self.norm_mlp(tokens + self.mlp(tokens)) + output_node_embeddings, output_edge_embeddings = torch.split( + tokens, [1, tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + return output_node_embeddings, output_edge_embeddings + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward 
pass for a single Transformer layer. + + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + if self.transformer_type == "PostLN": + node_embeddings, edge_embeddings = self._forward_post_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + if self.transformer_type == "PreLN": + node_embeddings, edge_embeddings = self._forward_pre_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class Transformer(torch.nn.Module): + """ + Transformer implementation. + + :param d_model: The dimension of the model. + :param num_layers: The number of transformer layers. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param attention_temperature: The temperature scaling factor for attention + scores. This is combined with the standard scaling by the square root of + the head dimension. + """ + + def __init__( + self, + d_model: int, + num_layers: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + attention_temperature: float = 1.0, + ) -> None: + super(Transformer, self).__init__() + if norm not in AVAILABLE_NORMALIZATIONS: + raise ValueError( + f"Unknown normalization flag: {norm}. " + f"Please choose from: {AVAILABLE_NORMALIZATIONS}" + ) + + if transformer_type not in AVAILABLE_TRANSFORMER_TYPES: + raise ValueError( + f"Unknown transformer flag: {transformer_type}. " + f"Please choose from: {AVAILABLE_TRANSFORMER_TYPES}" + ) + self.transformer_type = transformer_type + + if activation not in AVAILABLE_ACTIVATIONS: + raise ValueError( + f"Unknown activation flag: {activation}. " + f"Please choose from: {AVAILABLE_ACTIVATIONS}" + ) + + self.layers = nn.ModuleList( + [ + TransformerLayer( + d_model=d_model, + n_heads=n_heads, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + temperature=attention_temperature, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the Transformer. 
+ + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + for layer in self.layers: + node_embeddings, edge_embeddings = layer( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class CartesianTransformer(torch.nn.Module): + """ + Cartesian Transformer implementation for handling 3D coordinates. + + :param cutoff: The cutoff distance for neighbor interactions. + :param cutoff_width: The width of the cutoff function. + :param d_model: The dimension of the model. + :param n_head: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param n_layers: The number of transformer layers. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param attention_temperature: The temperature scaling factor for attention scores. + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param n_atomic_species: The number of atomic species. + :param is_first: Whether this is the first transformer in the model. + """ + + def __init__( + self, + cutoff: float, + cutoff_width: float, + d_model: int, + n_head: int, + dim_node_features: int, + dim_feedforward: int, + n_layers: int, + norm: str, + activation: str, + attention_temperature: float, + transformer_type: str, + n_atomic_species: int, + is_first: bool, + ) -> None: + super(CartesianTransformer, self).__init__() + self.is_first = is_first + self.cutoff = cutoff + self.cutoff_width = cutoff_width + self.trans = Transformer( + d_model=d_model, + num_layers=n_layers, + n_heads=n_head, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + attention_temperature=attention_temperature, + ) + + self.l_max = 10 + self.spherical_harmonics = SphericalHarmonics(l_max=self.l_max) + self.edge_embedder = nn.Linear( + (self.l_max + 1) ** 2 + (self.l_max + 1), d_model + ) + self.rmsnorm = nn.LayerNorm(d_model) + + if not is_first: + n_merge = 3 + else: + n_merge = 2 + + self.compress = nn.Sequential( + nn.Linear(n_merge * d_model, d_model), + torch.nn.SiLU(), + nn.Linear(d_model, d_model), + ) + + self.neighbor_embedder = DummyModule() # for torchscript + if not is_first: + self.neighbor_embedder = nn.Embedding(n_atomic_species, d_model) + + def forward( + self, + input_node_embeddings: torch.Tensor, + input_messages: torch.Tensor, + element_indices_neighbors: torch.Tensor, + edge_vectors: torch.Tensor, + padding_mask: torch.Tensor, + edge_distances: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the CartesianTransformer. 
+ + :param input_node_embeddings: The input node embeddings, of shape + (n_nodes, d_model) + :param input_messages: The input messages to the transformer, of shape + (n_nodes, max_num_neighbors, d_model) + :param element_indices_neighbors: The atomic species of the neighboring atoms, + of shape (n_nodes, max_num_neighbors) + :param edge_vectors: The cartesian edge vectors between the central atoms and + their neighbors, of shape (n_nodes, max_num_neighbors, 3) + :param padding_mask: A padding mask indicating which neighbors are real, and + which are padded, of shape (n_nodes, max_num_neighbors) + :param edge_distances: The distances between the central atoms and their + neighbors, of shape (n_nodes, max_num_neighbors) + :param cutoff_factors: The cutoff factors for the edges, of shape + (n_nodes, max_num_neighbors) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (n_nodes, d_model) + - The output edge embeddings, of shape (n_nodes, max_num_neighbors, d_model) + """ + node_embeddings = input_node_embeddings + # edge_embeddings = [edge_vectors, edge_distances[:, :, None]] + # edge_embeddings = torch.cat(edge_embeddings, dim=2) + edge_embeddings = self.spherical_harmonics(edge_vectors.reshape(-1, 3)).reshape( + edge_vectors.shape[0], edge_vectors.shape[1], -1 + ) + radial_poly = torch.norm(edge_vectors, dim=-1, keepdim=True).repeat( + 1, 1, self.l_max + 1 + ) ** torch.arange(self.l_max + 1, device=edge_vectors.device).reshape(1, 1, -1) + edge_embeddings = torch.concat([edge_embeddings, radial_poly], dim=-1) + edge_embeddings = self.edge_embedder(edge_embeddings) + edge_embeddings = self.rmsnorm(edge_embeddings) + + if not self.is_first: + neighbor_elements_embeddings = self.neighbor_embedder( + element_indices_neighbors + ) + edge_tokens = torch.cat( + [edge_embeddings, neighbor_elements_embeddings, input_messages], dim=2 + ) + else: + neighbor_elements_embeddings = torch.empty( + 0, device=edge_vectors.device, dtype=edge_vectors.dtype + ) # for torch script + edge_tokens = torch.cat([edge_embeddings, input_messages], dim=2) + + edge_tokens = self.compress(edge_tokens) + # tokens = torch.cat([node_elements_embedding[:, None, :], tokens], dim=1) + + padding_mask_with_central_token = torch.ones( + padding_mask.shape[0], dtype=torch.bool, device=padding_mask.device + ) + total_padding_mask = torch.cat( + [padding_mask_with_central_token[:, None], padding_mask], dim=1 + ) + + cutoff_subfactors = torch.ones( + padding_mask.shape[0], + dtype=cutoff_factors.dtype, + device=padding_mask.device, + ) + cutoff_factors = torch.cat([cutoff_subfactors[:, None], cutoff_factors], dim=1) + cutoff_factors[~total_padding_mask] = 0.0 + + cutoff_factors = cutoff_factors[:, None, :] + cutoff_factors = cutoff_factors.repeat(1, cutoff_factors.shape[2], 1) + + initial_num_tokens = edge_vectors.shape[1] + max_num_tokens = input_messages.shape[1] + + output_node_embeddings, output_edge_embeddings = self.trans( + node_embeddings[:, None, :], + edge_tokens[:, :max_num_tokens, :], + cutoff_factors=cutoff_factors[ + :, : (max_num_tokens + 1), : (max_num_tokens + 1) + ], + use_manual_attention=use_manual_attention, + ) + if max_num_tokens < initial_num_tokens: + padding = torch.zeros( + output_edge_embeddings.shape[0], + initial_num_tokens - max_num_tokens, + 
output_edge_embeddings.shape[2], + device=output_edge_embeddings.device, + ) + output_edge_embeddings = torch.cat([output_edge_embeddings, padding], dim=1) + output_node_embeddings = output_node_embeddings.squeeze(1) + return output_node_embeddings, output_edge_embeddings + + +def manual_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: torch.Tensor, + temperature: float, +) -> torch.Tensor: + """ + Implements the attention operation manually, using basic PyTorch operations. + We need it because the built-in PyTorch attention does not support double backward, + which is needed when training with conservative forces. + + :param q: The queries + :param k: The keys + :param v: The values + :param attn_mask: The attention mask + :param temperature: An additional scaling factor for attention scores. + :return: The result of the attention operation + """ + attention_weights = ( + torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5 * temperature) + ) + attn_mask + attention_weights = attention_weights.softmax(dim=-1) + attention_output = torch.matmul(attention_weights, v) + return attention_output diff --git a/src/metatrain/pet_6/modules/utilities.py b/src/metatrain/pet_6/modules/utilities.py new file mode 100644 index 0000000000..5795c5aed8 --- /dev/null +++ b/src/metatrain/pet_6/modules/utilities.py @@ -0,0 +1,63 @@ +import torch + + +def cutoff_func_bump( + values: torch.Tensor, cutoff: torch.Tensor, width: float, eps: float = 1e-6 +) -> torch.Tensor: + """ + Bump cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :param eps: Avoid computing at values too close to the edges. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + f[mask_active] = 0.5 * ( + 1 + torch.tanh(1 / torch.tan(torch.pi * scaled_values[mask_active])) + ) + f[mask_smaller] = 1.0 + + return f + + +def cutoff_func_cosine( + values: torch.Tensor, cutoff: torch.Tensor, width: float +) -> torch.Tensor: + """ + Cosine cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + + f[mask_active] = 0.5 + 0.5 * torch.cos(torch.pi * scaled_values[mask_active]) + f[mask_smaller] = 1.0 + return f + + +class DummyModule(torch.nn.Module): + """Dummy torch module to make torchscript happy. 
+ This model should never be run""" + + def __init__(self) -> None: + super(DummyModule, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise RuntimeError("This model should never be run") diff --git a/src/metatrain/pet_6/trainer.py b/src/metatrain/pet_6/trainer.py new file mode 100644 index 0000000000..49e82212bc --- /dev/null +++ b/src/metatrain/pet_6/trainer.py @@ -0,0 +1,639 @@ +import copy +import logging +import math +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Union, cast + +import torch +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader, DistributedSampler + +from metatrain.utils.abc import ModelInterface, TrainerInterface +from metatrain.utils.additive import get_remove_additive_transform +from metatrain.utils.augmentation import RotationalAugmenter +from metatrain.utils.data import ( + CollateFn, + CombinedDataLoader, + Dataset, + get_num_workers, + unpack_batch, + validate_num_workers, +) +from metatrain.utils.distributed.batch_utils import should_skip_batch +from metatrain.utils.distributed.distributed_data_parallel import ( + DistributedDataParallel, +) +from metatrain.utils.distributed.slurm import DistributedEnvironment +from metatrain.utils.evaluate_model import evaluate_model +from metatrain.utils.io import check_file_extension +from metatrain.utils.logging import ROOT_LOGGER, MetricLogger +from metatrain.utils.loss import LossAggregator, LossSpecification +from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists_transform, +) +from metatrain.utils.per_atom import average_by_num_atoms +from metatrain.utils.scaler import get_remove_scale_transform +from metatrain.utils.transfer import batch_to + +from . import checkpoints +from .documentation import TrainerHypers +from .model import PET +from .modules.finetuning import apply_finetuning_strategy + + +def get_scheduler( + optimizer: torch.optim.Optimizer, + train_hypers: TrainerHypers, + steps_per_epoch: int, +) -> LambdaLR: + """ + Get a CosineAnnealing learning-rate scheduler with warmup + + :param optimizer: The optimizer for which to create the scheduler. + :param train_hypers: The training hyperparameters. + :param steps_per_epoch: The number of steps per epoch. + :return: The learning rate scheduler. 
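+
+ As a summary of the ``lr_lambda`` defined below, the multiplier applied to the
+ base learning rate is::
+
+     step / warmup_steps                       if step < warmup_steps
+     0.5 * (1 + cos(pi * progress))            otherwise
+
+ with ``progress = (step - warmup_steps) / (total_steps - warmup_steps)``,
+ ``total_steps = num_epochs * steps_per_epoch`` and
+ ``warmup_steps = warmup_fraction * total_steps``.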
+ """ + total_steps = train_hypers["num_epochs"] * steps_per_epoch + warmup_steps = int(train_hypers["warmup_fraction"] * total_steps) + min_lr_ratio = 0.0 # hardcoded for now, could be made configurable in the future + + def lr_lambda(current_step: int) -> float: + if current_step < warmup_steps: + # Linear warmup + return float(current_step) / float(max(1, warmup_steps)) + else: + # Cosine decay + progress = (current_step - warmup_steps) / float( + max(1, total_steps - warmup_steps) + ) + cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) + return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay + + scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) + return scheduler + + +class Trainer(TrainerInterface[TrainerHypers]): + __checkpoint_version__ = 12 + + def __init__(self, hypers: TrainerHypers) -> None: + super().__init__(hypers) + + self.optimizer_state_dict: Optional[Dict[str, Any]] = None + self.scheduler_state_dict: Optional[Dict[str, Any]] = None + self.epoch: Optional[int] = None + self.best_epoch: Optional[int] = None + self.best_metric: Optional[float] = None + self.best_model_state_dict: Optional[Dict[str, Any]] = None + self.best_optimizer_state_dict: Optional[Dict[str, Any]] = None + + def train( + self, + model: PET, + dtype: torch.dtype, + devices: List[torch.device], + train_datasets: List[Union[Dataset, torch.utils.data.Subset]], + val_datasets: List[Union[Dataset, torch.utils.data.Subset]], + checkpoint_dir: str, + ) -> None: + assert dtype in PET.__supported_dtypes__ + + is_distributed = self.hypers["distributed"] + is_finetune = self.hypers["finetune"]["read_from"] is not None + + if is_distributed: + if len(devices) > 1: + raise ValueError( + "Requested distributed training with the `multi-gpu` device. " + " If you want to run distributed training with PET, please " + "set `device` to cuda." 
+ ) + # the calculation of the device number works both when GPUs on different + # processes are not visible to each other and when they are + distr_env = DistributedEnvironment(self.hypers["distributed_port"]) + device_number = distr_env.local_rank % torch.cuda.device_count() + device = torch.device("cuda", device_number) + torch.distributed.init_process_group(backend="nccl", device_id=device) + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + else: + rank = 0 + device = devices[0] + # only one device, as we don't support non-distributed multi-gpu for now + + if is_distributed: + logging.info(f"Training on {world_size} devices with dtype {dtype}") + else: + logging.info(f"Training on device {device} with dtype {dtype}") + + # Apply fine-tuning strategy if provided + if is_finetune: + assert self.hypers["finetune"]["read_from"] is not None # for mypy + model = apply_finetuning_strategy(model, self.hypers["finetune"]) + method = self.hypers["finetune"]["method"] + num_params = sum(p.numel() for p in model.parameters()) + num_trainable_params = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + + logging.info(f"Applied finetuning strategy: {method}") + logging.info( + f"Number of trainable parameters: {num_trainable_params} " + f"[{num_trainable_params / num_params:.2%} %]" + ) + inherit_heads = self.hypers["finetune"]["inherit_heads"] + if inherit_heads: + logging.info( + "Inheriting initial weights for heads and last layers for targets: " + f"from {list(inherit_heads.values())} to " + f"{list(inherit_heads.keys())}" + ) + + # Move the model to the device and dtype: + model.to(device=device, dtype=dtype) + # The additive models of PET are always in float64 (to avoid numerical errors in + # the composition weights, which can be very large). 
+ for additive_model in model.additive_models: + additive_model.to(dtype=torch.float64) + model.scaler.to(dtype=torch.float64) + + logging.info("Calculating composition weights") + model.additive_models[0].train_model( # this is the composition model + train_datasets, + model.additive_models[1:], + self.hypers["batch_size"], + is_distributed, + self.hypers["atomic_baseline"], + ) + + if self.hypers["scale_targets"]: + logging.info("Calculating scaling weights") + model.scaler.train_model( + train_datasets, + model.additive_models, + self.hypers["batch_size"], + is_distributed, + self.hypers["fixed_scaling_weights"], + ) + + logging.info("Setting up data loaders") + + if is_distributed: + train_samplers = [ + DistributedSampler( + train_dataset, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + ) + for train_dataset in train_datasets + ] + val_samplers = [ + DistributedSampler( + val_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + for val_dataset in val_datasets + ] + else: + train_samplers = [None] * len(train_datasets) + val_samplers = [None] * len(val_datasets) + + # Extract additive models and scaler and move them to CPU/float64 so they + # can be used in the collate function + model.additive_models[0].weights_to(device="cpu", dtype=torch.float64) + additive_models = copy.deepcopy( + model.additive_models.to(dtype=torch.float64, device="cpu") + ) + model.additive_models.to(device) + model.additive_models[0].weights_to(device=device, dtype=torch.float64) + model.scaler.scales_to(device="cpu", dtype=torch.float64) + scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu")) + model.scaler.to(device) + model.scaler.scales_to(device=device, dtype=torch.float64) + + # Create collate functions: + dataset_info = model.dataset_info + train_targets = dataset_info.targets + extra_data_info = dataset_info.extra_data + rotational_augmenter = RotationalAugmenter( + target_info_dict=train_targets, extra_data_info_dict=extra_data_info + ) + requested_neighbor_lists = get_requested_neighbor_lists(model) + collate_fn_train = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ + rotational_augmenter.apply_random_augmentations, + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + collate_fn_val = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ # no augmentation for validation + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + + # Create dataloader for the training datasets: + if self.hypers["num_workers"] is None: + num_workers = get_num_workers() + logging.info( + "Number of workers for data-loading not provided and chosen " + f"automatically. Using {num_workers} workers." + ) + else: + num_workers = self.hypers["num_workers"] + validate_num_workers(num_workers) + + train_dataloaders = [] + for train_dataset, train_sampler in zip( + train_datasets, train_samplers, strict=True + ): + if len(train_dataset) < self.hypers["batch_size"]: + raise ValueError( + f"A training dataset has fewer samples " + f"({len(train_dataset)}) than the batch size " + f"({self.hypers['batch_size']}). " + "Please reduce the batch size." 
+ ) + train_dataloaders.append( + DataLoader( + dataset=train_dataset, + batch_size=self.hypers["batch_size"], + sampler=train_sampler, + shuffle=( + # the sampler takes care of this (if present) + train_sampler is None + ), + drop_last=( + # the sampler takes care of this (if present) + train_sampler is None + ), + collate_fn=collate_fn_train, + num_workers=num_workers, + ) + ) + train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True) + + # Create dataloader for the validation datasets: + val_dataloaders = [] + for val_dataset, val_sampler in zip(val_datasets, val_samplers, strict=True): + val_dataloaders.append( + DataLoader( + dataset=val_dataset, + batch_size=self.hypers["batch_size"], + sampler=val_sampler, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_val, + num_workers=num_workers, + ) + ) + val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False) + + if is_distributed: + model = DistributedDataParallel(model, device_ids=[device]) + + outputs_list = [] + for target_name, target_info in train_targets.items(): + outputs_list.append(target_name) + for gradient_name in target_info.gradients: + outputs_list.append(f"{target_name}_{gradient_name}_gradients") + + # Create a loss function: + loss_hypers = cast(Dict[str, LossSpecification], self.hypers["loss"]) # mypy + loss_fn = LossAggregator(targets=train_targets, config=loss_hypers) + logging.info("Using the following loss functions:") + for name, info in loss_fn.metadata.items(): + logging.info(f"{name}:") + main = {k: v for k, v in info.items() if k != "gradients"} + logging.info(main) + if "gradients" not in info or len(info["gradients"]) == 0: + continue + logging.info("With gradients:") + for grad, ginfo in info["gradients"].items(): + logging.info(f"\t{name}::{grad}: {ginfo}") + + if self.hypers["weight_decay"] is not None: + optimizer = torch.optim.AdamW( + model.parameters(), + lr=self.hypers["learning_rate"], + weight_decay=self.hypers["weight_decay"], + ) + else: + optimizer = torch.optim.Adam( + model.parameters(), lr=self.hypers["learning_rate"] + ) + + if self.optimizer_state_dict is not None and not is_finetune: + # try to load the optimizer state dict, but this is only possible + # if there are no new targets in the model (new parameters) + if not (model.module if is_distributed else model).has_new_targets: + optimizer.load_state_dict(self.optimizer_state_dict) + + # Create a learning rate scheduler + lr_scheduler = get_scheduler(optimizer, self.hypers, len(train_dataloader)) + + if self.scheduler_state_dict is not None and not is_finetune: + # same as the optimizer, try to load the scheduler state dict + if not (model.module if is_distributed else model).has_new_targets: + lr_scheduler.load_state_dict(self.scheduler_state_dict) + + per_structure_targets = self.hypers["per_structure_targets"] + + # Log the initial learning rate: + logging.info(f"Base learning rate: {self.hypers['learning_rate']}") + + start_epoch = -1 if self.epoch is None else self.epoch + epoch = start_epoch + + # Save the untrained model checkpoint: + if self.best_metric is None: + self.best_metric = float("inf") + logging.info("Saving untrained model checkpoint") + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = 
lr_scheduler.state_dict() + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_untrained.ckpt", + ) + + # Train the model: + logging.info("Starting training") + start_epoch = start_epoch + 1 + + for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]): + if is_distributed: + for train_sampler in train_samplers: + train_sampler.set_epoch(epoch) + train_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + val_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + if self.hypers["log_mae"]: + train_mae_calculator = MAEAccumulator( + self.hypers["log_separate_blocks"] + ) + val_mae_calculator = MAEAccumulator(self.hypers["log_separate_blocks"]) + + train_loss = 0.0 + for batch in train_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + optimizer.zero_grad() + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=True, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + train_loss_batch = loss_fn(predictions, targets, extra_data) + + if is_distributed: + # make sure all parameters contribute to the gradient calculation + # to make torch DDP happy + for param in model.parameters(): + train_loss_batch += 0.0 * param.sum() + + train_loss_batch.backward() + torch.nn.utils.clip_grad_norm_( + model.parameters(), self.hypers["grad_clip_norm"] + ) + optimizer.step() + lr_scheduler.step() + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(train_loss_batch) + train_loss += train_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + train_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + train_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_train_info = train_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_train_info.update( + train_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + val_loss = 0.0 + for batch in val_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=False, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + val_loss_batch = loss_fn(predictions, targets, extra_data) + + if 
is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(val_loss_batch) + val_loss += val_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + val_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + val_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_val_info = val_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_val_info.update( + val_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + # Now we log the information: + finalized_train_info = { + "loss": train_loss, + **finalized_train_info, + } + finalized_val_info = { + "loss": val_loss, + **finalized_val_info, + } + + if epoch == start_epoch: + metric_logger = MetricLogger( + log_obj=ROOT_LOGGER, + dataset_info=( + model.module if is_distributed else model + ).dataset_info, + initial_metrics=[finalized_train_info, finalized_val_info], + names=["training", "validation"], + ) + if epoch % self.hypers["log_interval"] == 0: + metric_logger.log( + metrics=[finalized_train_info, finalized_val_info], + epoch=epoch, + rank=rank, + learning_rate=optimizer.param_groups[0]["lr"], + ) + + val_metric = get_selected_metric( + finalized_val_info, self.hypers["best_model_metric"] + ) + if val_metric < self.best_metric: + self.best_metric = val_metric + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + if epoch % self.hypers["checkpoint_interval"] == 0: + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + self.epoch = epoch + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_{epoch}.ckpt", + ) + + # prepare for the checkpoint that will be saved outside the function + self.epoch = epoch + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + + if is_distributed: + torch.distributed.destroy_process_group() + + def save_checkpoint(self, model: ModelInterface, path: Union[str, Path]) -> None: + checkpoint = model.get_checkpoint() + if self.best_model_state_dict is not None: + self.best_model_state_dict["finetune_config"] = model.finetune_config + checkpoint.update( + { + "trainer_ckpt_version": self.__checkpoint_version__, + "train_hypers": self.hypers, + "epoch": self.epoch, + "optimizer_state_dict": self.optimizer_state_dict, + "scheduler_state_dict": self.scheduler_state_dict, + "best_epoch": self.best_epoch, + "best_metric": self.best_metric, + "best_model_state_dict": self.best_model_state_dict, + "best_optimizer_state_dict": self.best_optimizer_state_dict, + } + ) + torch.save( + checkpoint, + check_file_extension(path, ".ckpt"), + ) + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + hypers: TrainerHypers, + context: Literal["restart", "finetune"], + ) -> "Trainer": + trainer = cls(hypers) + trainer.optimizer_state_dict = 
checkpoint["optimizer_state_dict"] + trainer.scheduler_state_dict = checkpoint["scheduler_state_dict"] + trainer.epoch = checkpoint["epoch"] + trainer.best_epoch = checkpoint["best_epoch"] + trainer.best_metric = checkpoint["best_metric"] + trainer.best_model_state_dict = checkpoint["best_model_state_dict"] + trainer.best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"] + + return trainer + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["trainer_ckpt_version"] == v: + update = getattr(checkpoints, f"trainer_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["trainer_ckpt_version"] = v + 1 + + if checkpoint["trainer_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using " + f"trainer version {checkpoint['trainer_ckpt_version']}, while the " + f"current trainer version is {cls.__checkpoint_version__}." + ) + return checkpoint diff --git a/src/metatrain/pet_7/__init__.py b/src/metatrain/pet_7/__init__.py new file mode 100644 index 0000000000..01db7d8aa9 --- /dev/null +++ b/src/metatrain/pet_7/__init__.py @@ -0,0 +1,21 @@ +from .model import PET +from .trainer import Trainer + + +__model__ = PET +__trainer__ = Trainer +__capabilities__ = { + "supported_devices": __model__.__supported_devices__, + "supported_dtypes": __model__.__supported_dtypes__, +} + +__authors__ = [ + ("Sergey Pozdnyakov ", "@spozdn"), + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] + +__maintainers__ = [ + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] diff --git a/src/metatrain/pet_7/checkpoints.py b/src/metatrain/pet_7/checkpoints.py new file mode 100644 index 0000000000..020b1e6590 --- /dev/null +++ b/src/metatrain/pet_7/checkpoints.py @@ -0,0 +1,422 @@ +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + + +########################### +# MODEL ################### +########################### + + +def model_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "additive_models.0.model.type_to_index" not in state_dict: + state_dict["additive_models.0.model.type_to_index"] = state_dict.pop( + "additive_models.0.type_to_index" + ) + + +def model_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "train_hypers" in checkpoint: + finetune_config = checkpoint["train_hypers"].get("finetune", {}) + else: + finetune_config = {} + state_dict["finetune_config"] = finetune_config + + +def model_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["epoch"] = checkpoint.get("epoch") + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + if checkpoint["best_model_state_dict"] is not None: + checkpoint["best_model_state_dict"] = checkpoint.get("best_model_state_dict") + + +def model_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 checkpoint to v5. + + :param checkpoint: The checkpoint to update. 
+ """ + checkpoint["model_data"]["dataset_info"]._atomic_types = list( + checkpoint["model_data"]["dataset_info"]._atomic_types + ) + + +def model_update_v5_v6(checkpoint: dict) -> None: + """ + Update a v5 checkpoint to v6. + + :param checkpoint: The checkpoint to update. + """ + if not checkpoint["best_model_state_dict"]: + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + +def model_update_v6_v7(checkpoint: dict) -> None: + """ + Update model checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to be updated. + """ + # this update consists in changes in the scaler + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if ( + "scaler.scales" not in state_dict + and "scaler.dummy_buffer" in state_dict + and "scaler.model.type_to_index" in state_dict + ): + continue # already updated + old_scales_tensor = state_dict.pop("scaler.scales") + old_output_name_to_output_index = {} + for target_index, target_name in enumerate( + checkpoint["model_data"]["dataset_info"].targets.keys() + ): + old_output_name_to_output_index[target_name] = target_index + state_dict["scaler.dummy_buffer"] = torch.tensor( + [0.0], dtype=old_scales_tensor.dtype + ) + state_dict["scaler.model.type_to_index"] = state_dict[ + "additive_models.0.model.type_to_index" + ] + for target_name, target_info in checkpoint["model_data"][ + "dataset_info" + ].targets.items(): + layout = target_info.layout + if layout.sample_names == ["system"]: + samples = Labels(["atomic_type"], torch.tensor([[-1]])) + + elif layout.sample_names == ["system", "atom"]: + samples = Labels( + ["atomic_type"], + torch.arange( + len(checkpoint["model_data"]["dataset_info"].atomic_types) + ).reshape(-1, 1), + ) + else: + raise ValueError( # will never happen + "Unknown sample kind. Please contact the developers." + ) + scales_tensormap = TensorMap( + keys=layout.keys, + blocks=[ + TensorBlock( + values=torch.full( # important when scale_targets=False + (len(samples), len(block.properties)), + old_scales_tensor[ + old_output_name_to_output_index[target_name] + ], + dtype=torch.float64, + ), + samples=samples, + components=[], + properties=block.properties, + ) + for block in layout.blocks() + ], + ) + state_dict[f"scaler.{target_name}_scaler_buffer"] = mts.save_buffer( + mts.make_contiguous(scales_tensormap) + ) + + +def model_update_v7_v8(checkpoint: dict) -> None: + """ + Update a v7 checkpoint to v8. + + :param checkpoint: The checkpoint to update. + """ + if "node_embedding.weight" in checkpoint["model_state_dict"]: + ############################################## + # **Updating the large-scale PET checkpoints** + ############################################## + + checkpoint["model_data"]["model_hypers"]["normalization"] = "RMSNorm" + checkpoint["model_data"]["model_hypers"]["activation"] = "SwiGLU" + checkpoint["model_data"]["model_hypers"]["d_node"] = ( + 4 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PreLN" + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "feedforward" + checkpoint["model_data"]["model_hypers"]["d_feedforward"] = ( + 2 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + if (state_dict := checkpoint.get("model_state_dict")) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + if "embedding." 
in k and "node" not in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if "node_embedding." in k: + k = k.replace("node_embedding.", "node_embedders.0.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + if "combination_rmsnorms" in k: + k = k.replace("combination_rmsnorms", "combination_norms") + new_state_dict[k] = v + checkpoint["model_state_dict"] = new_state_dict + + # for the large-scale checkpoints, the model for evaluation is always + # taken to be the last + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + else: + ############################################### + # **Updating the standard old PET checkpoints** + ############################################### + + # Adding the option for choosing the normalization type + if "normalization" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["normalization"] = "LayerNorm" + # Adding the option for choosing the activation function + if "activation" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["activation"] = "SiLU" + # Setting the node features dimension to be the same as d_pet if not specified + if "d_node" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["d_node"] = checkpoint[ + "model_data" + ]["model_hypers"]["d_pet"] + # Setting the default transformer type to PostLN if not specified + if "transformer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PostLN" + if "featurizer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "residual" + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + # Replacing the nn.Sequential MLP with a custom FeedForward module + if ".mlp.0" in k: + k = k.replace(".mlp.0", ".mlp.w_in") + if ".mlp.3" in k: + k = k.replace(".mlp.3", ".mlp.w_out") + # Moving the node embedder to a top-level PET model attribute + if "embedding." in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".node_embedder." in k: + key_content = k.split(".") + k = ".".join( + ["node_embedders", key_content[1], key_content[-1]] + ) + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + new_state_dict[k] = v + checkpoint[key] = new_state_dict + + +def model_update_v8_v9(checkpoint: dict) -> None: + """ + Update a v8 checkpoint to v9. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "finetune_config" in state_dict: + if "inherit_heads" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["inherit_heads"] = {} + if "method" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["method"] = "full" + + +def model_update_v9_v10(checkpoint: dict) -> None: + """ + Update a v9 checkpoint to v10. + + :param checkpoint: The checkpoint to update. 
+ """ + # Adding the num_neighbors_adaptive hyperparameter if not present + if "num_neighbors_adaptive" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["num_neighbors_adaptive"] = None + if "cutoff_function" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["cutoff_function"] = "Cosine" + + +def model_update_v10_v11(checkpoint: dict) -> None: + """ + Update a v10 checkpoint to v11. + + :param checkpoint: The checkpoint to update. + """ + # Adding the attention_temperature hyperparameter if not present + if "attention_temperature" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["attention_temperature"] = 1.0 + + +########################### +# TRAINER ################# +########################### + + +def trainer_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 Trainer checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["scheduler_factor"] = checkpoint["train_hypers"].get( + "scheduler_factor", 0.5 + ) + + +def trainer_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 Trainer checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + +def trainer_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 Trainer checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + new_loss_hypers[target_name] = { + "type": old_loss_hypers["type"], + "weight": old_loss_hypers["weights"].get(target_name, 1.0), + "reduction": old_loss_hypers["reduction"], + "sliding_factor": old_loss_hypers.get("sliding_factor", None), + } + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 Trainer checkpoint to v5. + + :param checkpoint: The checkpoint to update. + """ + raise ValueError( + "In order to use this checkpoint, you need metatrain 2025.10 or earlier. " + "You can install it with `pip install metatrain==2025.10`." + ) + + +def trainer_update_v5_v6(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 5 to version 6. + + :param checkpoint: The checkpoint to update. + """ + # num_workers=0 means that the main process will do the data loading, which is + # equivalent to not setting it (this was the behavior before v6) + checkpoint["train_hypers"]["num_workers"] = 0 + + +def trainer_update_v6_v7(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["fixed_scaling_weights"] = {} + + +def trainer_update_v7_v8(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. 
+ """ + # remove all entries in the loss `sliding_factor` + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + # retain everything except sliding_factor for each target + new_loss_hypers[target_name] = { + k: v + for k, v in old_loss_hypers[target_name].items() + if k != "sliding_factor" + } + + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v8_v9(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. + """ + # Adding the empty finetune config if not present + if "finetune" not in checkpoint["train_hypers"]: + checkpoint["train_hypers"]["finetune"] = { + "read_from": None, + "method": "full", + "config": {}, + "inherit_heads": {}, + } + + +def trainer_update_v9_v10(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 9 to version 10. + + :param checkpoint: The checkpoint to update. + """ + # Ensuring that the finetune read_from is None if not specified + checkpoint["train_hypers"]["remove_composition_contribution"] = True + + +def trainer_update_v10_v11(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 10 to version 11. + + :param checkpoint: The checkpoint to update. + """ + # - Remove the ``remove_composition_contribution`` hyper. + # - Rename ``fixed_composition_weights`` to ``atomic_baseline``. + # - If ``remove_composition_contribution`` is False, set all atomic baselines + # to 0.0 for all targets. + use_atomic_baseline = checkpoint["train_hypers"].pop( + "remove_composition_contribution" + ) + atomic_baseline = checkpoint["train_hypers"].pop("fixed_composition_weights") + + if not use_atomic_baseline: + # Just set + dataset_info = checkpoint["model_data"]["dataset_info"] + atomic_baseline = {target_name: 0.0 for target_name in dataset_info.targets} + + checkpoint["train_hypers"]["atomic_baseline"] = atomic_baseline + + +def trainer_update_v11_v12(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 11 to version 12. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["batch_atom_bounds"] = [None, None] diff --git a/src/metatrain/pet_7/documentation.py b/src/metatrain/pet_7/documentation.py new file mode 100644 index 0000000000..f394ff2081 --- /dev/null +++ b/src/metatrain/pet_7/documentation.py @@ -0,0 +1,251 @@ +""" +PET +=== + +PET is a cleaner, more user-friendly reimplementation of the original +PET model :footcite:p:`pozdnyakov_smooth_2023`. It is designed for better +modularity and maintainability, while preseving compatibility with the original +PET implementation in ``metatrain``. It also adds new features like long-range +features, better fine-tuning implementation, a possibility to train on +arbitrarty targets, and a faster inference due to the ``fast attention``. + +{{SECTION_INSTALLATION}} + +{{SECTION_DEFAULT_HYPERS}} + +Tuning hyperparameters +---------------------- + +The default hyperparameters above will work well in most cases, but they +may not be optimal for your specific dataset. There is good number of +parameters to tune, both for the +:ref:`model ` and the +:ref:`trainer `. Since seeing them +for the first time might be overwhelming, here we provide a **list of the +parameters that are in general the most important** (in decreasing order +of importance): + +.. container:: mtt-hypers-remove-classname + + .. 
autoattribute:: {{model_hypers_path}}.cutoff + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_neighbors_adaptive + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.learning_rate + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.batch_size + :no-index: + + .. autoattribute:: {{model_hypers_path}}.d_pet + :no-index: + + .. autoattribute:: {{model_hypers_path}}.d_node + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_gnn_layers + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_attention_layers + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.loss + :no-index: + + .. autoattribute:: {{model_hypers_path}}.long_range + :no-index: +""" + +from typing import Literal, Optional + +from typing_extensions import TypedDict + +from metatrain.pet_7.modules.finetuning import FinetuneHypers, NoFinetuneHypers +from metatrain.utils.additive import FixedCompositionWeights +from metatrain.utils.hypers import init_with_defaults +from metatrain.utils.long_range import LongRangeHypers +from metatrain.utils.loss import LossSpecification +from metatrain.utils.scaler import FixedScalerWeights + + +class ModelHypers(TypedDict): + """Hyperparameters for the PET model.""" + + cutoff: float = 4.5 + """Cutoff radius for neighbor search. + + This should be set to a value after which most of the interactions + between atoms is expected to be negligible. A lower cutoff will lead + to faster models. + """ + num_neighbors_adaptive: Optional[int] = None + """Target number of neighbors for the adaptive cutoff scheme. + + This parameter activates the adaptive cutoff functionality. + Each atomic environments has a different cutoff, that is chosen + such that the number of neighbors is approximately equal to this + value. This can be useful to have a more uniform number of neighbors + per atom, especially in sparse systems. Setting it to None disables + this feature and uses all neighbors within the fixed cutoff radius. + """ + cutoff_function: Literal["Cosine", "Bump"] = "Bump" + """Type of the smoothing function at the cutoff""" + cutoff_width: float = 0.5 + """Width of the smoothing function at the cutoff""" + d_pet: int = 128 + """Dimension of the edge features. + + This hyperparameters controls width of the neural network. In general, + increasing it might lead to better accuracy, especially on larger datasets, at the + cost of increased training and evaluation time. + """ + d_head: int = 128 + """Dimension of the attention heads.""" + d_node: int = 256 + """Dimension of the node features. + + Increasing this hyperparameter might lead to better accuracy, + with a relatively small increase in inference time. + """ + d_feedforward: int = 256 + """Dimension of the feedforward network in the attention layer.""" + num_heads: int = 8 + """Attention heads per attention layer.""" + num_attention_layers: int = 2 + """The number of attention layers in each layer of the graph + neural network. Depending on the dataset, increasing this hyperparameter might + lead to better accuracy, at the cost of increased training and evaluation time. + """ + num_gnn_layers: int = 2 + """The number of graph neural network layers. + + In general, decreasing this hyperparameter to 1 will lead to much faster models, + at the expense of accuracy. Increasing it may or may not lead to better accuracy, + depending on the dataset, at the cost of increased training and evaluation time. 
+ """ + normalization: Literal["RMSNorm", "LayerNorm"] = "RMSNorm" + """Layer normalization type.""" + activation: Literal["SiLU", "SwiGLU"] = "SwiGLU" + """Activation function.""" + attention_temperature: float = 1.0 + """The temperature scaling factor for attention scores.""" + transformer_type: Literal["PreLN", "PostLN"] = "PreLN" + """The order in which the layer normalization and attention + are applied in a transformer block. Available options are ``PreLN`` + (normalization before attention) and ``PostLN`` (normalization after attention).""" + featurizer_type: Literal["residual", "feedforward"] = "feedforward" + """Implementation of the featurizer of the model to use. Available + options are ``residual`` (the original featurizer from the PET paper, that uses + residual connections at each GNN layer for readout) and ``feedforward`` (a modern + version that uses the last representation after all GNN iterations for readout). + Additionally, the feedforward version uses bidirectional features flow during the + message passing iterations, that favors features flowing from atom ``i`` to atom + ``j`` to be not equal to the features flowing from atom ``j`` to atom ``i``.""" + zbl: bool = False + """Use ZBL potential for short-range repulsion""" + long_range: LongRangeHypers = init_with_defaults(LongRangeHypers) + """Long-range Coulomb interactions parameters.""" + + +class TrainerHypers(TypedDict): + """Hyperparameters for training PET models.""" + + distributed: bool = False + """Whether to use distributed training""" + distributed_port: int = 39591 + """Port for distributed communication among processes""" + batch_size: int = 16 + """The number of samples to use in each batch of training. This + hyperparameter controls the tradeoff between training speed and memory usage. In + general, larger batch sizes will lead to faster training, but might require more + memory.""" + num_epochs: int = 1000 + """Number of epochs.""" + warmup_fraction: float = 0.01 + """Fraction of training steps used for learning rate warmup.""" + learning_rate: float = 1e-4 + """Learning rate.""" + weight_decay: Optional[float] = None + + log_interval: int = 1 + """Interval to log metrics.""" + checkpoint_interval: int = 100 + """Interval to save checkpoints.""" + atomic_baseline: FixedCompositionWeights = {} + """The baselines for each target. + + By default, ``metatrain`` will fit a linear model (:class:`CompositionModel + `) to compute the + least squares baseline for each atomic species for each target. + + However, this hyperparameter allows you to provide your own baselines. + The value of the hyperparameter should be a dictionary where the keys are the + target names, and the values are either (1) a single baseline to be used for + all atomic types, or (2) a dictionary mapping atomic types to their baselines. + For example: + + - ``atomic_baseline: {"energy": {1: -0.5, 6: -10.0}}`` will fix the energy + baseline for hydrogen (Z=1) to -0.5 and for carbon (Z=6) to -10.0, while + fitting the baselines for the energy of all other atomic types, as well + as fitting the baselines for all other targets. + - ``atomic_baseline: {"energy": -5.0}`` will fix the energy baseline for + all atomic types to -5.0. + - ``atomic_baseline: {"mtt:dos": 0.0}`` sets the baseline for the "mtt:dos" + target to 0.0, effectively disabling the atomic baseline for that target. 
+ + This atomic baseline is substracted from the targets during training, which + avoids the main model needing to learn atomic contributions, and likely makes + training easier. When the model is used in evaluation mode, the atomic baseline + is added on top of the model predictions automatically. + + .. note:: + + This atomic baseline is a per-atom contribution. Therefore, if the property + you are predicting is a sum over all atoms (e.g., total energy), the + contribution of the atomic baseline to the total property will be the + atomic baseline multiplied by the number of atoms of that type in the + structure. + """ + scale_targets: bool = True + """Normalize targets to unit std during training.""" + fixed_scaling_weights: FixedScalerWeights = {} + """Weights for target scaling. + + This is passed to the ``fixed_weights`` argument of + :meth:`Scaler.train_model `, + see its documentation to understand exactly what to pass here. + """ + per_structure_targets: list[str] = [] + """Targets to calculate per-structure losses.""" + num_workers: Optional[int] = None + """Number of workers for data loading. If not provided, it is set + automatically.""" + log_mae: bool = True + """Log MAE alongside RMSE""" + log_separate_blocks: bool = False + """Log per-block error.""" + best_model_metric: Literal["rmse_prod", "mae_prod", "loss"] = "mae_prod" + """Metric used to select best checkpoint (e.g., ``rmse_prod``)""" + grad_clip_norm: float = 1.0 + """Maximum gradient norm value.""" + loss: str | dict[str, LossSpecification | str] = "mse" + """This section describes the loss function to be used. See the + :ref:`loss-functions` for more details.""" + batch_atom_bounds: list[Optional[int]] = [None, None] + """Bounds for the number of atoms per batch as [min, max]. Batches with atom + counts outside these bounds will be skipped during training. Use ``None`` for + either value to disable that bound. This is useful for preventing out-of-memory + errors and ensuring consistent computational load. Default: ``[None, None]``.""" + + finetune: NoFinetuneHypers | FinetuneHypers = { + "read_from": None, + "method": "full", + "config": {}, + "inherit_heads": {}, + } + """Parameters for fine-tuning trained PET models. + + See :ref:`label_fine_tuning_concept` for more details. + """ diff --git a/src/metatrain/pet_7/model.py b/src/metatrain/pet_7/model.py new file mode 100644 index 0000000000..b62444b0b8 --- /dev/null +++ b/src/metatrain/pet_7/model.py @@ -0,0 +1,1646 @@ +import logging +import typing +import warnings +from math import prod +from typing import Any, Dict, List, Literal, Optional, Tuple + +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.operations._add import _add_block_block +from metatomic.torch import ( + AtomisticModel, + ModelCapabilities, + ModelMetadata, + ModelOutput, + NeighborListOptions, + System, +) +from torch.utils.hooks import RemovableHandle + +from metatrain.utils.abc import ModelInterface +from metatrain.utils.additive import ZBL, CompositionModel +from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.dtype import dtype_to_str +from metatrain.utils.long_range import DummyLongRangeFeaturizer, LongRangeFeaturizer +from metatrain.utils.metadata import merge_metadata +from metatrain.utils.scaler import Scaler +from metatrain.utils.sum_over_atoms import sum_over_atoms + +from . 
import checkpoints +from .documentation import ModelHypers +from .modules.finetuning import apply_finetuning_strategy +from .modules.structures import get_pair_sample_labels, systems_to_batch +from .modules.transformer import CartesianTransformer + + +AVAILABLE_FEATURIZERS = typing.get_args(ModelHypers.__annotations__["featurizer_type"]) + + +class PET(ModelInterface[ModelHypers]): + """ + Metatrain-native implementation of the PET architecture. + + Originally proposed in work (https://arxiv.org/abs/2305.19302v3), + and published in the `pet` package (https://github.com/spozdn/pet). + + :param hypers: Hyperparameters for the PET model. See the documentation for details. + :param dataset_info: Information about the dataset, including atomic types and + targets. + """ + + __checkpoint_version__ = 11 + __supported_devices__ = ["cuda", "cpu"] + __supported_dtypes__ = [torch.float32, torch.float64] + __default_metadata__ = ModelMetadata( + references={"architecture": ["https://arxiv.org/abs/2305.19302v3"]} + ) + component_labels: Dict[str, List[List[Labels]]] + NUM_FEATURE_TYPES: int = 2 # node + edge features + + def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None: + super().__init__(hypers, dataset_info, self.__default_metadata__) + + # Cache frequently accessed hyperparameters + self.cutoff = float(self.hypers["cutoff"]) + self.cutoff_function = self.hypers["cutoff_function"] + self.cutoff_width = float(self.hypers["cutoff_width"]) + self.num_neighbors_adaptive = ( + float(self.hypers["num_neighbors_adaptive"]) + if self.hypers["num_neighbors_adaptive"] is not None + else None + ) + self.d_pet = self.hypers["d_pet"] + self.d_node = self.hypers["d_node"] + self.d_head = self.hypers["d_head"] + self.d_feedforward = self.hypers["d_feedforward"] + self.num_heads = self.hypers["num_heads"] + self.num_gnn_layers = self.hypers["num_gnn_layers"] + self.num_attention_layers = self.hypers["num_attention_layers"] + self.normalization = self.hypers["normalization"] + self.activation = self.hypers["activation"] + self.attention_temperature = self.hypers["attention_temperature"] + self.transformer_type = self.hypers["transformer_type"] + self.featurizer_type = self.hypers["featurizer_type"] + + self.atomic_types = dataset_info.atomic_types + self.requested_nl = NeighborListOptions( + cutoff=self.cutoff, + full_list=True, + strict=True, + ) + num_atomic_species = len(self.atomic_types) + self.gnn_layers = torch.nn.ModuleList( + [ + CartesianTransformer( + self.cutoff, + self.cutoff_width, + self.d_pet, + self.num_heads, + self.d_node, + self.d_feedforward, + self.num_attention_layers, + self.normalization, + self.activation, + self.attention_temperature, + self.transformer_type, + num_atomic_species, + layer_index == 0, # is first layer + ) + for layer_index in range(self.num_gnn_layers) + ] + ) + if self.featurizer_type not in AVAILABLE_FEATURIZERS: + raise ValueError( + f"Unknown featurizer type: {self.featurizer_type}. 
" + f"Available options are: {AVAILABLE_FEATURIZERS}" + ) + if self.featurizer_type == "feedforward": + self.num_readout_layers = 1 + self.combination_norms = torch.nn.ModuleList( + [torch.nn.LayerNorm(2 * self.d_pet) for _ in range(self.num_gnn_layers)] + ) + self.combination_mlps = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(2 * self.d_pet, 2 * self.d_pet), + torch.nn.SiLU(), + torch.nn.Linear(2 * self.d_pet, self.d_pet), + ) + for _ in range(self.num_gnn_layers) + ] + ) + else: + self.num_readout_layers = self.num_gnn_layers + self.combination_norms = torch.nn.ModuleList() + self.combination_mlps = torch.nn.ModuleList() + + self.node_embedders = torch.nn.ModuleList( + [ + torch.nn.Embedding(num_atomic_species, self.d_node) + for _ in range(self.num_readout_layers) + ] + ) + self.edge_embedder = torch.nn.Embedding(num_atomic_species, self.d_pet) + + self.node_heads = torch.nn.ModuleDict() + self.edge_heads = torch.nn.ModuleDict() + self.node_last_layers = torch.nn.ModuleDict() + self.edge_last_layers = torch.nn.ModuleDict() + self.last_layer_feature_size = ( + self.num_readout_layers * self.d_head * self.NUM_FEATURE_TYPES + ) # for LLPR + + # the model is always capable of outputting the internal features + self.outputs = { + "features": ModelOutput(per_atom=True, description="internal features") + } + + self.output_shapes: Dict[str, Dict[str, List[int]]] = {} + self.key_labels: Dict[str, Labels] = {} + self.property_labels: Dict[str, List[Labels]] = {} + self.component_labels: Dict[str, List[List[Labels]]] = {} + self.target_names: List[str] = [] + self.last_layer_parameter_names: Dict[str, List[str]] = {} # for LLPR + for target_name, target_info in dataset_info.targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target_info) + + self.register_buffer( + "species_to_species_index", + torch.full((max(self.atomic_types) + 1,), -1), + ) + for i, species in enumerate(self.atomic_types): + self.species_to_species_index[species] = i + + # long-range module + if self.hypers["long_range"]["enable"]: + self.long_range = True + if not self.hypers["long_range"]["use_ewald"]: + warnings.warn( + "Training PET with the LongRangeFeaturizer initialized " + "with `use_ewald=False` causes instabilities during training. " + "The `use_ewald` variable will be force-switched to `True`. 
" + "during training.", + UserWarning, + stacklevel=2, + ) + self.long_range_featurizer = LongRangeFeaturizer( + hypers=self.hypers["long_range"], + feature_dim=self.d_node, + neighbor_list_options=self.requested_nl, + ) + else: + self.long_range = False + self.long_range_featurizer = DummyLongRangeFeaturizer() # for torchscript + + # additive models: these are handled by the trainer at training + # time, and they are added to the output at evaluation time + composition_model = CompositionModel( + hypers={}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + additive_models = [composition_model] + + # Adds the ZBL repulsion model if requested + if self.hypers["zbl"]: + additive_models.append( + ZBL( + {}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if ZBL.is_valid_target(target_name, target_info) + }, + ), + ) + ) + self.additive_models = torch.nn.ModuleList(additive_models) + + # scaler: this is also handled by the trainer at training time + self.scaler = Scaler(hypers={}, dataset_info=dataset_info) + + self.single_label = Labels.single() + + self.finetune_config: Dict[str, Any] = {} + + def supported_outputs(self) -> Dict[str, ModelOutput]: + return self.outputs + + def restart(self, dataset_info: DatasetInfo) -> "PET": + # merge old and new dataset info + merged_info = self.dataset_info.union(dataset_info) + new_atomic_types = [ + at for at in merged_info.atomic_types if at not in self.atomic_types + ] + new_targets = { + key: value + for key, value in merged_info.targets.items() + if key not in self.dataset_info.targets + } + self.has_new_targets = len(new_targets) > 0 + + if len(new_atomic_types) > 0: + raise ValueError( + f"New atomic types found in the dataset: {new_atomic_types}. " + "The PET model does not support adding new atomic types." + ) + + # register new outputs as new last layers + for target_name, target in new_targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target) + + self.dataset_info = merged_info + + # restart the composition and scaler models + self.additive_models[0] = self.additive_models[0].restart( + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + self.scaler = self.scaler.restart(dataset_info) + + return self + + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return [self.requested_nl] + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + """ + Forward pass of the PET model. + + The forward pass processes atomic systems through multiple stages to produce + predictions for the requested outputs. The computation follows a graph neural + network architecture with attention-based message passing. 
+ + **Stage 0: Input Preparation** + + The input systems are first converted into a batched representation containing: + + - `element_indices_nodes` [n_atoms]: Atomic species of the central atoms + - `element_indices_neighbors` [n_atoms, max_num_neighbors]: Atomic species of + neighboring atoms + - `edge_vectors` [n_atoms, max_num_neighbors, 3]: Cartesian edge vectors + between central atoms and their neighbors + - `padding_mask` [n_atoms, max_num_neighbors]: Mask indicating real vs padded + neighbors + - `reverse_neighbor_index` [n_atoms * max_num_neighbors]: Index of the ji edge + for each ij edge, once the edges are flattened into an array whose first + dimension is n_atoms * max_num_neighbors + - `system_indices` [n_atoms]: System index for each central atom + - `sample_labels` [n_atoms, 2]: Metatensor Labels containing indices of each + atom in each system + + **Stage 1: Feature Computation via GNN Layers** + + Node and edge representations are computed by iterating through the GNN layers + following one of two featurization strategies: + + - **Feedforward featurization**: Features are propagated through all + `num_gnn_layers` GNN layers sequentially, using only the final layer outputs + for readout. At each layer, forward and reversed edge messages are combined + using combination MLPs to enable bidirectional information flow. + + - **Residual featurization**: Intermediate node and edge features from each + GNN layer are saved and used during readout. Edge messages between layers are + averaged to maintain information from all hops. + + During this stage, the model: + + - Embeds atomic species into learned node and edge representations + - Applies Cartesian transformer layers to update features via attention + - Uses reversed neighbor lists to enable bidirectional message passing, where + the new input message from atom `j` to atom `i` in GNN layer N+1 is the + reversed message from atom `i` to atom `j` in GNN layer N + - Applies cutoff functions to weight interactions by distance + + If the long-range module is enabled, electrostatic features computed via Ewald + summation (during training) or Particle-Particle Particle Mesh Ewald (P3M) + (during evaluation) are added to the node features from each GNN layer. + + **Stage 2: Intermediate Feature Output (Optional)** + + If "features" is requested in the outputs, node and edge features from all + layers are concatenated to produce intermediate representations. Edge features + are summed over neighbors with cutoff weighting to obtain per-node + contributions. This output can be used for transfer learning or analysis. + + **Stage 3: Last Layer Feature Computation** + + For each requested output, output-specific heads (shallow MLPs with two linear + layers and SiLU activations) are applied to both node and edge features from + each GNN layer. This produces last layer features that are specialized for each + prediction target. These features can be optionally returned as auxiliary + outputs (e.g., "mtt::aux::energy_last_layer_features") for analysis or + transfer learning. 
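+
+        As a concrete sketch of the reversed-message reordering described in Stage 1
+        above (mirroring the flatten/index/reshape done inside the featurization
+        helpers; ``n_atoms``, ``max_num_neighbors`` and ``d`` are illustrative
+        names)::
+
+            flat = edge_messages.reshape(n_atoms * max_num_neighbors, d)
+            reversed_messages = flat[reverse_neighbor_index].reshape(
+                n_atoms, max_num_neighbors, d
+            )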
+ + **Stage 4: Atomic Predictions** + + Final linear layers are applied to the last layer features to produce per-atom + predictions for each requested output: + + - Node and edge last layer features are processed through separate linear + layers for each output block + - Contributions from all GNN layers are summed + - Edge contributions are summed over neighbors with cutoff weighting + - For rank-2 Cartesian tensors (e.g., stress), predictions are symmetrized and + normalized by cell volume + - Multiple tensor blocks per output are handled independently + + **Post-processing (Evaluation Only)** + + During evaluation (not training), the following transformations are applied: + + 1. **Scaling**: Predictions are scaled using learned or configured scale + factors + 2. **Additive contributions**: Composition model and optional ZBL repulsion + contributions are added to the predictions + + :param systems: List of `metatomic.torch.System` objects to process. Each + system should contain atomic positions, species, and cell information, with + neighbor lists computed according to `requested_neighbor_lists()`. + :param outputs: Dictionary of requested outputs in the format + {output_name: ModelOutput(...)}. The model supports: + + - Target properties (energy, forces, stress, etc.) + - "features": intermediate representations from Stage 2 + - Auxiliary last layer features (e.g., + "mtt::aux::energy_last_layer_features") + + :param selected_atoms: Optional `metatensor.torch.Labels` object specifying a + subset of atoms for which to compute outputs. If `None`, all atoms are + included. This is useful for computing properties for specific atomic + environments. + :return: Dictionary of `metatensor.torch.TensorMap` objects containing the + requested outputs. Each TensorMap contains per-atom or per-structure + predictions (depending on the ModelOutput configuration) with appropriate + metatensor metadata (samples, components, properties). 
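+
+        Example (a minimal sketch; ``"energy"`` is only an illustrative target name
+        and assumes such a target is present in the model's ``dataset_info``)::
+
+            outputs = {
+                "energy": ModelOutput(per_atom=False),
+                "features": ModelOutput(per_atom=True),
+            }
+            predictions = model(systems, outputs)
+            energies = predictions["energy"].block().values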
+ """ + device = systems[0].device + return_dict: Dict[str, TensorMap] = {} + nl_options = self.requested_neighbor_lists()[0] + + if self.single_label.values.device != device: + self._move_labels_to_device(device) + + # **Stage 0: Input Preparation** + ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + padding_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + sample_labels, + centers, + nef_to_edges_neighbor, + ) = systems_to_batch( + systems, + nl_options, + self.atomic_types, + self.species_to_species_index, + self.cutoff_function, + self.cutoff_width, + self.num_neighbors_adaptive, + selected_atoms, + ) + + pair_sample_labels = get_pair_sample_labels( + systems, sample_labels, nl_options, device + ) + + # Optional diagnostic token capture: register temporary module hooks + diagnostic_handles = torch.jit.annotate(List[Any], []) + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + diagnostic_handles = self._prepare_diagnostic_handles( + outputs, + return_dict, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + + # the scaled_dot_product_attention function from torch cannot do + # double backward, so we will use manual attention if needed + use_manual_attention = edge_vectors.requires_grad and self.training + + with torch.profiler.record_function("PET::_calculate_features"): + # **Stage 1: Feature Computation via GNN Layers** + featurizer_inputs: Dict[str, torch.Tensor] = dict( + element_indices_nodes=element_indices_nodes, + element_indices_neighbors=element_indices_neighbors, + edge_vectors=edge_vectors, + edge_distances=edge_distances, + reverse_neighbor_index=reverse_neighbor_index, + padding_mask=padding_mask, + cutoff_factors=cutoff_factors, + ) + for featurizer_input_name, tensor in featurizer_inputs.items(): + if "mtt::features::" + featurizer_input_name in outputs: + return_dict["mtt::features::" + featurizer_input_name] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + node_features_list, edge_features_list = self._calculate_features( + featurizer_inputs, + use_manual_attention=use_manual_attention, + ) + + # If the long-range module is activated, we add the long-range features + # on top of the node features + + if self.long_range: + long_range_features = self._calculate_long_range_features( + systems, node_features_list, edge_distances, padding_mask + ) + for i in range(self.num_readout_layers): + node_features_list[i] = ( + node_features_list[i] + long_range_features + ) * 0.5**0.5 + + # **Stage 2: Intermediate Feature Output (Optional)** + with torch.profiler.record_function("PET::_get_output_features"): + if "features" in outputs: + features_dict = self._get_output_features( + node_features_list, + edge_features_list, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + # Since return_dict.update(features_dict) is not Torch-Scriptable, + # we use a simple iteration over the features_dict items. 
+ for k, v in features_dict.items(): + return_dict[k] = v + + # **Stage 3: Last Layer Feature Computation** + with torch.profiler.record_function("PET::_calculate_last_layer_features"): + node_last_layer_features_dict, edge_last_layer_features_dict = ( + self._calculate_last_layer_features( + node_features_list, + edge_features_list, + ) + ) + last_layer_features_dict = self._get_output_last_layer_features( + node_last_layer_features_dict, + edge_last_layer_features_dict, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + + for k, v in last_layer_features_dict.items(): + return_dict[k] = v + + # **Stage 4: Atomic Predictions** + with torch.profiler.record_function("PET::_calculate_atomic_predictions"): + node_atomic_predictions_dict, edge_atomic_predictions_dict = ( + self._calculate_atomic_predictions( + node_last_layer_features_dict, + edge_last_layer_features_dict, + padding_mask, + cutoff_factors, + outputs, + ) + ) + atomic_predictions_dict = self._get_output_atomic_predictions( + systems, + node_atomic_predictions_dict, + edge_atomic_predictions_dict, + edge_vectors, + system_indices, + sample_labels, + outputs, + selected_atoms, + ) + + for k, v in atomic_predictions_dict.items(): + return_dict[k] = v + + # **Post-processing (Evaluation Only)** + + with torch.profiler.record_function("PET::post-processing"): + if not self.training: + # at evaluation, we also introduce the scaler and additive contributions + return_dict = self.scaler( + systems, return_dict, selected_atoms=selected_atoms + ) + for additive_model in self.additive_models: + outputs_for_additive_model: Dict[str, ModelOutput] = {} + for name, output in outputs.items(): + if name in additive_model.outputs: + outputs_for_additive_model[name] = output + additive_contributions = additive_model( + systems, + outputs_for_additive_model, + selected_atoms, + ) + for name in additive_contributions: + # TODO: uncomment this after metatensor.torch.add + # is updated to handle sparse sums + # return_dict[name] = metatensor.torch.add( + # return_dict[name], + # additive_contributions[name].to( + # device=return_dict[name].device, + # dtype=return_dict[name].dtype + # ), + # ) + # TODO: "manual" sparse sum: update to metatensor.torch.add + # after sparse sum is implemented in metatensor.operations + output_blocks: List[TensorBlock] = [] + for k, b in return_dict[name].items(): + if k in additive_contributions[name].keys: + output_blocks.append( + _add_block_block( + b, + additive_contributions[name] + .block(k) + .to(device=b.device, dtype=b.dtype), + ) + ) + else: + output_blocks.append(b) + return_dict[name] = TensorMap( + return_dict[name].keys, output_blocks + ) + + # remove any diagnostic hooks we registered and attach tokens + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + for h in diagnostic_handles: + h.remove() + + return return_dict + + def _create_diagnostic_feature_tensormap( + self, + tensor: torch.Tensor, + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> TensorMap: + assert tensor.shape[0] == sample_labels.values.shape[0], ( + "diagnostic feature tensor must be per-atom or per-pair like in shape." 
+ f" Got tensor.shape = {tensor.shape}, " + ) + + device = tensor.device + + outp = tensor.detach().clone() + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + if outp.shape[1] == 1: # node-like, shape (n_atoms, 1, d) + outp = outp.squeeze(1) + labels = sample_labels + + else: # edge-like, shape (n_atoms, num_neighbors, d) + outp = outp[centers, nef_to_edges_neighbor] + labels = pair_sample_labels + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + return TensorMap( + Labels(["_"], torch.tensor([[0]])).to(device=device), + [ + TensorBlock( + values=outp, + samples=labels.to(device=device), + components=[], + properties=Labels( + ["_"], + torch.arange(outp.shape[1]).reshape(-1, 1), + ).to(device=device), + ) + ], + ) + + @torch.jit.ignore + def _prepare_diagnostic_handles( + self, + outputs: Dict[str, ModelOutput], + return_dict: Dict[str, Any], + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> List[Any]: + """ + Prepare forward hooks to capture diagnostic tokens from internal modules. + + :param outputs: Dictionary of requested outputs. + :param return_dict: Dictionary to store captured tokens. + :param centers: Tensor mapping center atoms to their indices. + :param nef_to_edges_neighbor: Tensor mapping neighbor edges to their indices. + :param sample_labels: Labels for individual atoms. + :param pair_sample_labels: Labels for atom pairs. + + :return: List of removable handles for the registered hooks. + """ + diagnostic_handles = [] + + def _resolve_module(path: str) -> Any: + obj: Any = self + for part in path.split("."): + if part.isdigit(): + obj = obj[int(part)] + else: + if not hasattr(obj, part): + raise AttributeError( + f"Module path '{path}' not found at '{part}'" + ) + obj = getattr(obj, part) + return obj + + # Build list of possible module paths that can be captured for these model + # hypers. 
+ possible_capture_paths: List[str] = [] + + for i in range(self.num_readout_layers): + possible_capture_paths.append(f"node_embedders.{i}") + possible_capture_paths.append("edge_embedder") + + for i in range(self.num_gnn_layers): + # Total GNN layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}_node") + possible_capture_paths.append(f"gnn_layers.{i}_edge") + + # Finer-grained: embeddings and compressions + possible_capture_paths.append(f"gnn_layers.{i}.edge_embedder") + if i > 0: + possible_capture_paths.append(f"gnn_layers.{i}.neighbor_embedder") + possible_capture_paths.append(f"gnn_layers.{i}.compress") + + # transformer layers + for j in range(self.num_attention_layers): + # Transformer layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_node") + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_edge") + + # Finer-grained: attention and MLP submodules + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_contraction" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_mlp" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_mlp" + ) + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}.mlp") + + # TODO: this is subtle, depending on whether the featurizer is feedforward or + # residual. To fix later. + # possible_capture_paths.append(f"combination_norms") + # possible_capture_paths.append(f"combination_mlps") + + for path in possible_capture_paths: + if "mtt::features::" + path not in outputs: + continue + + if "_node" in path: + suffix = "_node" + path = path.replace(suffix, "") + elif "_edge" in path: + suffix = "_edge" + path = path.replace(suffix, "") + else: + suffix = "" + + module = _resolve_module(path) + + def make_hook(p: str, suffix: str) -> Any: + def _hook( + module: torch.nn.Module, inp: Any, outp: Any + ) -> None: + if isinstance(outp, tuple): + assert "_node" in suffix or "_edge" in suffix, ( + "When capturing from a module that returns multiple " + "outputs, the requested output must carry the suffix " + "'_node' or '_edge'." + ) + if suffix == "_node": + tensor = outp[0] + else: + assert suffix == "_edge" + tensor = outp[1] + + return_dict[f"mtt::features::{p}{suffix}"] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + else: + return_dict["mtt::features::" + p] = ( + self._create_diagnostic_feature_tensormap( + outp, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + + return _hook + + handle = module.register_forward_hook(make_hook(path, suffix)) + diagnostic_handles.append(handle) + + return diagnostic_handles + + def _calculate_features( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Calculate node and edge features using the selected featurization strategy. + Returns lists of feature tensors from GNN layers. 
+ + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors + - List of edge feature tensors + In the case of feedforward featurization, each list contains a single tensor + from the final GNN layer. In the case of residual featurization, each list + contains tensors from all GNN layers. + """ + if self.featurizer_type == "feedforward": + return self._feedforward_featurization_impl(inputs, use_manual_attention) + else: + return self._residual_featurization_impl(inputs, use_manual_attention) + + def _feedforward_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Feedforward featurization: iterates features through all GNN layers, + returning only the final layer outputs. Uses combination MLPs to mix + forward and reversed edge messages at each layer. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from the final GNN layer + - List of edge feature tensors from the final GNN layer + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + + input_node_embeddings = self.node_embedders[0](inputs["element_indices_nodes"]) + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for combination_norm, combination_mlp, gnn_layer in zip( + self.combination_norms, self.combination_mlps, self.gnn_layers, strict=True + ): + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + input_node_embeddings = output_node_embeddings + new_input_edge_embeddings = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + # input_messages = 0.5 * (output_edge_embeddings + new_input_messages) + concatenated = torch.cat( + [output_edge_embeddings, new_input_edge_embeddings], dim=-1 + ) + input_edge_embeddings = ( + input_edge_embeddings + + output_edge_embeddings + + combination_mlp(combination_norm(concatenated)) + ) + + node_features_list.append(input_node_embeddings) + edge_features_list.append(input_edge_embeddings) + return node_features_list, edge_features_list + + def _residual_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Residual featurization: saves intermediate features from each GNN layer + for use in readout. 
Averages forward and reversed edge messages between layers. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from all GNN layers + - List of edge feature tensors from all GNN layers + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for node_embedder, gnn_layer in zip( + self.node_embedders, self.gnn_layers, strict=True + ): + input_node_embeddings = node_embedder(inputs["element_indices_nodes"]) + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + node_features_list.append(output_node_embeddings) + edge_features_list.append(output_edge_embeddings) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + # (Flatten, index, and reshape to the original shape) + new_input_messages = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + input_edge_embeddings = 0.5 * (input_edge_embeddings + new_input_messages) + return node_features_list, edge_features_list + + def _calculate_long_range_features( + self, + systems: List[System], + node_features_list: List[torch.Tensor], + edge_distances: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + """ + Calculate long-range electrostatic features using Ewald summation. + Forces use_ewald=True during training for stability. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_distances: Tensor of edge distances [n_atoms, max_num_neighbors]. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :return: Tensor of long-range features [n_atoms, d_pet]. + """ + if self.training: + # Currently, the long-range implementation show instabilities + # during training if P3MCalculator is used instead of the + # EwaldCalculator. We will use the EwaldCalculator for training. 
+ self.long_range_featurizer.use_ewald = True + flattened_lengths = edge_distances[padding_mask] + short_range_features = ( + torch.stack(node_features_list).sum(dim=0) + * (1 / len(node_features_list)) ** 0.5 + ) + long_range_features = self.long_range_featurizer( + systems, short_range_features, flattened_lengths + ) + return long_range_features + + def _get_output_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Concatenate node and edge features from all layers into intermediate + feature representations. Edge features are summed with cutoff weighting. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping "features" to a TensorMap of intermediate + representations, either per-atom or summed over atoms. + """ + features_dict: Dict[str, TensorMap] = {} + node_features = torch.cat(node_features_list, dim=1) + edge_features = torch.cat(edge_features_list, dim=2) + edge_features = (edge_features * cutoff_factors[:, :, None]).sum(dim=1) + features = torch.cat([node_features, edge_features], dim=1) + + feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=features, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + features.shape[-1], device=features.device + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + feature_tmap = mts.slice( + feature_tmap, + axis="samples", + selection=selected_atoms, + ) + if requested_outputs["features"].per_atom: + features_dict["features"] = feature_tmap + else: + features_dict["features"] = sum_over_atoms(feature_tmap) + return features_dict + + def _calculate_last_layer_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + ) -> Tuple[Dict[str, List[torch.Tensor]], Dict[str, List[torch.Tensor]]]: + """ + Apply output-specific heads to node and edge features from each GNN layer. + Returns dictionaries mapping output names to lists of head-transformed features. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of node last layer features + - Dictionary mapping output names to lists of edge last layer features + """ + node_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + + # Calculating node last layer features + for output_name, node_heads in self.node_heads.items(): + if output_name not in node_last_layer_features_dict: + node_last_layer_features_dict[output_name] = [] + for i, node_head in enumerate(node_heads): + node_last_layer_features_dict[output_name].append( + node_head(node_features_list[i]) + ) + + # Calculating edge last layer features + for output_name, edge_heads in self.edge_heads.items(): + if output_name not in edge_last_layer_features_dict: + edge_last_layer_features_dict[output_name] = [] + for i, edge_head in enumerate(edge_heads): + edge_last_layer_features_dict[output_name].append( + edge_head(edge_features_list[i]) + ) + + return node_last_layer_features_dict, edge_last_layer_features_dict + + def _get_output_last_layer_features( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge last layer features for requested last layer + features output. Edge features are summed with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping requested last layer features output names + to TensorMaps of last layer features, either per-atom or summed over atoms. 
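+
+        For example (the target names here are purely illustrative), a target
+        called ``energy`` corresponds to the auxiliary output
+        ``mtt::aux::energy_last_layer_features``, and ``mtt::my_target``
+        corresponds to ``mtt::aux::my_target_last_layer_features``: the
+        ``mtt::`` prefix is stripped before the auxiliary name is built.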
+ """ + last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + last_layer_features_outputs: Dict[str, TensorMap] = {} + for output_name in node_last_layer_features_dict.keys(): + if not should_compute_last_layer_features(output_name, requested_outputs): + continue + if output_name not in last_layer_features_dict: + last_layer_features_dict[output_name] = [] + for i in range(len(node_last_layer_features_dict[output_name])): + node_last_layer_features = node_last_layer_features_dict[output_name][i] + edge_last_layer_features = edge_last_layer_features_dict[output_name][i] + edge_last_layer_features = ( + edge_last_layer_features * cutoff_factors[:, :, None] + ).sum(dim=1) + last_layer_features_dict[output_name].append(node_last_layer_features) + last_layer_features_dict[output_name].append(edge_last_layer_features) + + for output_name in requested_outputs: + if not ( + output_name.startswith("mtt::aux::") + and output_name.endswith("_last_layer_features") + ): + continue + base_name = output_name.replace("mtt::aux::", "").replace( + "_last_layer_features", "" + ) + # the corresponding output could be base_name or mtt::base_name + if f"mtt::{base_name}" in last_layer_features_dict: + base_name = f"mtt::{base_name}" + last_layer_features_values = torch.cat( + last_layer_features_dict[base_name], dim=1 + ) + last_layer_feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=last_layer_features_values, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + last_layer_features_values.shape[-1], + device=last_layer_features_values.device, + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + last_layer_feature_tmap = mts.slice( + last_layer_feature_tmap, + axis="samples", + selection=selected_atoms, + ) + last_layer_features_options = requested_outputs[output_name] + if last_layer_features_options.per_atom: + last_layer_features_outputs[output_name] = last_layer_feature_tmap + else: + last_layer_features_outputs[output_name] = sum_over_atoms( + last_layer_feature_tmap + ) + return last_layer_features_outputs + + def _calculate_atomic_predictions( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + padding_mask: torch.Tensor, + cutoff_factors: torch.Tensor, + outputs: Dict[str, ModelOutput], + ) -> Tuple[ + Dict[str, List[List[torch.Tensor]]], Dict[str, List[List[torch.Tensor]]] + ]: + """ + Apply final linear layers to last layer features to produce + per-atom predictions. Handles multiple blocks per output and sums + edge contributions with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param outputs: Dictionary of requested outputs. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of lists of node atomic + prediction tensors (one list per GNN layer, one tensor per block) + - Dictionary mapping output names to lists of lists of edge atomic + prediction tensors (one list per GNN layer, one tensor per block) + """ + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + + # Computing node atomic predictions. Since we have last layer features + # for each GNN layer, and each last layer can have multiple blocks, + # we apply each last layer block to each of the last layer features. + + for output_name, node_last_layers in self.node_last_layers.items(): + if output_name in outputs: + node_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, node_last_layer in enumerate(node_last_layers): + node_last_layer_features = node_last_layer_features_dict[ + output_name + ][i] + node_atomic_predictions_by_block: List[torch.Tensor] = [] + for node_last_layer_by_block in node_last_layer.values(): + node_atomic_predictions_by_block.append( + node_last_layer_by_block(node_last_layer_features) + ) + node_atomic_predictions_dict[output_name].append( + node_atomic_predictions_by_block + ) + + # Computing edge atomic predictions. Following the same logic as above, + # we (1) iterate over the last layer features and last layer blocks, and (2) + # sum the edge features with cutoff factors to get their per-node contribution. + + for output_name, edge_last_layers in self.edge_last_layers.items(): + if output_name in outputs: + edge_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, edge_last_layer in enumerate(edge_last_layers): + edge_last_layer_features = edge_last_layer_features_dict[ + output_name + ][i] + edge_atomic_predictions_by_block: List[torch.Tensor] = [] + for edge_last_layer_by_block in edge_last_layer.values(): + edge_atomic_predictions = edge_last_layer_by_block( + edge_last_layer_features + ) + expanded_padding_mask = padding_mask[..., None].repeat( + 1, 1, edge_atomic_predictions.shape[2] + ) + edge_atomic_predictions = torch.where( + ~expanded_padding_mask, 0.0, edge_atomic_predictions + ) + edge_atomic_predictions_by_block.append( + (edge_atomic_predictions * cutoff_factors[:, :, None]).sum( + dim=1 + ) + ) + edge_atomic_predictions_dict[output_name].append( + edge_atomic_predictions_by_block + ) + + return node_atomic_predictions_dict, edge_atomic_predictions_dict + + def _get_output_atomic_predictions( + self, + systems: List[System], + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_vectors: torch.Tensor, + system_indices: torch.Tensor, + sample_labels: Labels, + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge atomic predictions into final TensorMaps. + Handles rank-2 Cartesian tensors by symmetrizing them. + Returns per-atom or per-structure predictions based on output configuration. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of node atomic prediction tensors (one list per GNN layer, + one tensor per block). 
+ :param edge_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of edge atomic prediction tensors (one list per GNN layer, + one tensor per block). + :param edge_vectors: Tensor of edge vectors [n_atoms, max_num_neighbors, 3]. + :param system_indices: Tensor mapping each atom to its system index + [n_atoms]. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param outputs: Dictionary of requested outputs. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :return: Dictionary mapping requested output names to TensorMaps of + predictions, either per-atom or summed over atoms. + """ + atomic_predictions_tmap_dict: Dict[str, TensorMap] = {} + for output_name in self.target_names: + if output_name in outputs: + atomic_predictions_by_block = { + key: torch.zeros( + 1, dtype=edge_vectors.dtype, device=edge_vectors.device + ) + for key in self.output_shapes[output_name].keys() + } + + node_atomic_predictions_by_block = node_atomic_predictions_dict[ + output_name + ] + edge_atomic_predictions_by_block = edge_atomic_predictions_dict[ + output_name + ] + for i in range(len(node_atomic_predictions_by_block)): + node_atomic_prediction_block = node_atomic_predictions_by_block[i] + edge_atomic_prediction_block = edge_atomic_predictions_by_block[i] + for j, key in enumerate(atomic_predictions_by_block): + node_atomic_predictions = node_atomic_prediction_block[j] + edge_atomic_predictions = edge_atomic_prediction_block[j] + atomic_predictions_by_block[key] = atomic_predictions_by_block[ + key + ] + (node_atomic_predictions + edge_atomic_predictions) + + if output_name == "non_conservative_stress": # TODO: variants + block_key = list(atomic_predictions_by_block.keys())[0] + output_shapes_values = list( + self.output_shapes[output_name].values() + ) + num_properties = output_shapes_values[0][-1] + symmetrized = process_non_conservative_stress( + atomic_predictions_by_block[block_key], + systems, + system_indices, + num_properties, + ) + atomic_predictions_by_block[block_key] = symmetrized + + blocks = [ + TensorBlock( + values=atomic_predictions_by_block[key].reshape([-1] + shape), + samples=sample_labels, + components=components, + properties=properties, + ) + for key, shape, components, properties in zip( + self.output_shapes[output_name].keys(), + self.output_shapes[output_name].values(), + self.component_labels[output_name], + self.property_labels[output_name], + strict=True, + ) + ] + atomic_predictions_tmap_dict[output_name] = TensorMap( + keys=self.key_labels[output_name], + blocks=blocks, + ) + # If selected atoms request is provided, we slice the atomic predictions + # tensor maps to get the predictions for the selected atoms only. + + if selected_atoms is not None: + for output_name, tmap in atomic_predictions_tmap_dict.items(): + atomic_predictions_tmap_dict[output_name] = mts.slice( + tmap, axis="samples", selection=selected_atoms + ) + + # If per-atom predictions are requested, we return the atomic predictions + # tensor maps. Otherwise, we sum the atomic predictions over the atoms + # to get the final per-structure predictions for each requested output. 
+ + for output_name, atomic_property in atomic_predictions_tmap_dict.items(): + if outputs[output_name].per_atom: + atomic_predictions_tmap_dict[output_name] = atomic_property + else: + atomic_predictions_tmap_dict[output_name] = sum_over_atoms( + atomic_property + ) + + return atomic_predictions_tmap_dict + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + context: Literal["restart", "finetune", "export"], + ) -> "PET": + if context == "restart": + logging.info(f"Using latest model from epoch {checkpoint['epoch']}") + model_state_dict = checkpoint["model_state_dict"] + elif context in {"finetune", "export"}: + logging.info(f"Using best model from epoch {checkpoint['best_epoch']}") + model_state_dict = checkpoint["best_model_state_dict"] + else: + raise ValueError("Unknown context tag for checkpoint loading!") + + # Create the model + model_data = checkpoint["model_data"] + model = cls( + hypers=model_data["model_hypers"], + dataset_info=model_data["dataset_info"], + ) + + finetune_config = model_state_dict.pop("finetune_config", {}) + if finetune_config: + # Apply the finetuning strategy + model = apply_finetuning_strategy(model, finetune_config) + state_dict_iter = iter(model_state_dict.values()) + next(state_dict_iter) # skip the species_to_species_index + dtype = next(state_dict_iter).dtype + model.to(dtype).load_state_dict(model_state_dict) + model.additive_models[0].sync_tensor_maps() + model.scaler.sync_tensor_maps() + + # Loading the metadata from the checkpoint + model.metadata = merge_metadata(model.metadata, checkpoint.get("metadata")) + + return model + + def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: + dtype = next(self.parameters()).dtype + if dtype not in self.__supported_dtypes__: + raise ValueError(f"unsupported dtype {dtype} for PET") + + # Make sure the model is all in the same dtype + # For example, after training, the additive models could still be in + # float64 + self.to(dtype) + + # Additionally, the composition model contains some `TensorMap`s that cannot + # be registered correctly with Pytorch. This function moves them: + self.additive_models[0].weights_to(torch.device("cpu"), torch.float64) + + interaction_ranges = [self.num_gnn_layers * self.cutoff] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + + capabilities = ModelCapabilities( + outputs=self.outputs, + atomic_types=self.atomic_types, + interaction_range=interaction_range, + length_unit=self.dataset_info.length_unit, + supported_devices=self.__supported_devices__, + dtype=dtype_to_str(dtype), + ) + + metadata = merge_metadata(self.metadata, metadata) + + return AtomisticModel(self.eval(), metadata, capabilities) + + def _add_output(self, target_name: str, target_info: TargetInfo) -> None: + """ + Register a new output target by creating corresponding heads and last layers. + Sets up node/edge heads and linear layers for all readout layers. + + :param target_name: Name of the target to add. + :param target_info: TargetInfo object containing details about the target. + """ + # one output shape for each tensor block, grouped by target (i.e. 
tensormap) + self.output_shapes[target_name] = {} + for key, block in target_info.layout.items(): + dict_key = target_name + for n, k in zip(key.names, key.values, strict=True): + dict_key += f"_{n}_{int(k)}" + self.output_shapes[target_name][dict_key] = [ + len(comp.values) for comp in block.components + ] + [len(block.properties.values)] + + self.outputs[target_name] = ModelOutput( + quantity=target_info.quantity, + unit=target_info.unit, + per_atom=True, + description=target_info.description, + ) + + self.node_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_node, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_pet, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.node_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: torch.nn.Linear( + self.d_head, + prod(shape), + bias=True, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: torch.nn.Linear( + self.d_head, + prod(shape), + bias=True, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + # Register last-layer parameters, in the same order as they are returned as + # last-layer features in the model + self.last_layer_parameter_names[target_name] = [] + for layer_index in range(self.num_readout_layers): + for key in self.output_shapes[target_name].keys(): + self.last_layer_parameter_names[target_name].append( + f"node_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + self.last_layer_parameter_names[target_name].append( + f"edge_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + + ll_features_name = get_last_layer_features_name(target_name) + self.outputs[ll_features_name] = ModelOutput( + per_atom=True, description=f"last layer features for {target_name}" + ) + self.key_labels[target_name] = target_info.layout.keys + self.component_labels[target_name] = [ + block.components for block in target_info.layout.blocks() + ] + self.property_labels[target_name] = [ + block.properties for block in target_info.layout.blocks() + ] + + def _move_labels_to_device(self, device: torch.device) -> None: + self.single_label = self.single_label.to(device) + self.key_labels = { + output_name: label.to(device) + for output_name, label in self.key_labels.items() + } + self.component_labels = { + output_name: [ + [labels.to(device) for labels in components_block] + for components_block in components_tmap + ] + for output_name, components_tmap in self.component_labels.items() + } + self.property_labels = { + output_name: [labels.to(device) for labels in properties_tmap] + for output_name, properties_tmap in self.property_labels.items() + } + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["model_ckpt_version"] == v: + update = getattr(checkpoints, f"model_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["model_ckpt_version"] = v + 1 + + if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__: + raise 
RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using model " + f"version {checkpoint['model_ckpt_version']}, while the current model " + f"version is {cls.__checkpoint_version__}." + ) + + return checkpoint + + def get_checkpoint(self) -> Dict: + model_state_dict = self.state_dict() + model_state_dict["finetune_config"] = self.finetune_config + checkpoint = { + "architecture_name": "pet_7", + "model_ckpt_version": self.__checkpoint_version__, + "metadata": self.metadata, + "model_data": { + "model_hypers": self.hypers, + "dataset_info": self.dataset_info, + }, + "epoch": None, + "best_epoch": None, + "model_state_dict": model_state_dict, + "best_model_state_dict": self.state_dict(), + } + return checkpoint + + +def process_non_conservative_stress( + tensor: torch.Tensor, + systems: List[System], + system_indices: torch.Tensor, + num_properties: int, +) -> torch.Tensor: + """ + Symmetrizes and normalizes by the volume rank-2 Cartesian tensors that are meant + to predict the non-conservative stress. + + :param tensor: Tensor of shape [n_atoms, 9 * num_properties]. + :param systems: List of `metatomic.torch.System` objects to process. + :param system_indices: Tensor mapping each atom to its system index [n_atoms]. + :param num_properties: Number of properties in the tensor (e.g., 6 for stress). + :return: Symmetrized tensor of shape [n_atoms, 3, 3, num_properties], divided by the + cell volume. + """ + # Reshape to 3x3 matrix per atom + tensor_as_three_by_three = tensor.reshape(-1, 3, 3, num_properties) + + # Normalize by cell volume + volumes = torch.stack([torch.abs(torch.det(system.cell)) for system in systems]) + # Zero volume can happen due to metatomic's convention of zero cell + # vectors for non-periodic directions. The actual volume is +inf + volumes[volumes == 0.0] = torch.inf + volumes_by_atom = volumes[system_indices].unsqueeze(1).unsqueeze(2).unsqueeze(3) + tensor_as_three_by_three = tensor_as_three_by_three / volumes_by_atom + + # Symmetrize + tensor_as_three_by_three = ( + tensor_as_three_by_three + tensor_as_three_by_three.transpose(1, 2) + ) / 2.0 + + return tensor_as_three_by_three + + +def get_last_layer_features_name(target_name: str) -> str: + """ + Get the auxiliary output name for last layer features of a target. + + :param target_name: Name of the target. + :return: Name of the corresponding last layer features output. + """ + base_name = target_name.replace("mtt::", "") + return f"mtt::aux::{base_name}_last_layer_features" + + +def should_compute_last_layer_features( + output_name: str, requested_outputs: Dict[str, ModelOutput] +) -> bool: + """ + Check if last layer features should be computed for an output. + + :param output_name: Name of the output to check. + :param requested_outputs: Dictionary of requested outputs. + :return: True if last layer features should be computed, False otherwise. + """ + if output_name in requested_outputs: + return True + ll_features_name = get_last_layer_features_name( + output_name.replace("mtt::aux::", "") + ) + return ll_features_name in requested_outputs diff --git a/src/metatrain/pet_7/modules/adaptive_cutoff.py b/src/metatrain/pet_7/modules/adaptive_cutoff.py new file mode 100644 index 0000000000..5f6206bdb1 --- /dev/null +++ b/src/metatrain/pet_7/modules/adaptive_cutoff.py @@ -0,0 +1,166 @@ +from typing import Optional + +import torch + +from .utilities import cutoff_func_bump as cutoff_func + + +# minimum value for the probe cutoff. this avoids getting too close +# to the central atom. 
in practice it could be also set to a larger value +DEFAULT_MIN_PROBE_CUTOFF = 0.5 +# recommended smooth cutoff width for effective neighbor number calculation +# smaller values lead to a more "step-like" behavior, but can be +# numerically unstable. in practice this will be called with the +# same cutoff as the main cutoff function +DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH = 1.0 + + +def get_adaptive_cutoffs( + centers: torch.Tensor, + edge_distances: torch.Tensor, + num_neighbors_adaptive: float, + num_nodes: int, + max_cutoff: float, + min_cutoff: float = DEFAULT_MIN_PROBE_CUTOFF, + cutoff_width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, + probe_spacing: Optional[float] = None, + weight_width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the adaptive cutoff values for each center atom. + + :param centers: Indices of the center atoms. + :param edge_distances: Distances between centers and their neighbors. + :param num_neighbors_adaptive: Target number of neighbors per atom. + :param num_nodes: Total number of center atoms. + :param max_cutoff: Maximum cutoff distance to consider. + :param min_cutoff: Minimum cutoff distance to consider. + :param cutoff_width: Width of the smooth cutoff taper region. + :param probe_spacing: Spacing between probe cutoffs. If None, it will be + automatically determined from the cutoff width. + :param weight_width: Width of the cutoff selection weight function. If None, it + will be automatically determined from the empirical neighbor counts. + :return: Adapted cutoff distances for each center atom. + """ + + # heuristic for the grid spacing of probe cutoffs, based on a + # the smoothness of the cutoff function + if probe_spacing is None: + probe_spacing = cutoff_width / 4.0 + probe_cutoffs = torch.arange( + min_cutoff, + max_cutoff, + probe_spacing, + device=edge_distances.device, + dtype=edge_distances.dtype, + ) + with torch.profiler.record_function("PET::get_effective_num_neighbors"): + effective_num_neighbors = get_effective_num_neighbors( + edge_distances, + probe_cutoffs, + centers, + num_nodes, + width=cutoff_width, + ) + + with torch.profiler.record_function("PET::get_cutoff_weights"): + cutoffs_weights = get_gaussian_cutoff_weights( + effective_num_neighbors, + num_neighbors_adaptive, + width=weight_width, + ) + with torch.profiler.record_function("PET::calculate_adapted_cutoffs"): + adapted_atomic_cutoffs = probe_cutoffs @ cutoffs_weights.T + return adapted_atomic_cutoffs + + +def get_effective_num_neighbors( + edge_distances: torch.Tensor, + probe_cutoffs: torch.Tensor, + centers: torch.Tensor, + num_nodes: int, + width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, +) -> torch.Tensor: + """ + Computes the effective number of neighbors for each probe cutoff. + + :param edge_distances: Distances between centers and their neighbors. + :param probe_cutoffs: Probe cutoff distances. + :param centers: Indices of the center atoms. + :param num_nodes: Total number of center atoms. + :param width: Width of the cutoff function. + :return: Effective number of neighbors for each center atom and probe cutoff. 
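+
+    In other words, for a probe cutoff ``r_p`` the effective count of a center
+    atom ``i`` is ``sum_j f_cut(d_ij, r_p, width)``, where ``f_cut`` is the bump
+    cutoff function imported above: a smooth, differentiable surrogate for the
+    hard count of neighbors within ``r_p``.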
+ """ + + weights = cutoff_func( + edge_distances.unsqueeze(0), probe_cutoffs.unsqueeze(1), width + ) + + probe_num_neighbors = torch.zeros( + (len(probe_cutoffs), num_nodes), + dtype=edge_distances.dtype, + device=edge_distances.device, + ) + # accumulate the weights for all probe cutoffs and center atoms at once + probe_num_neighbors.index_add_(1, centers, weights) + probe_num_neighbors = probe_num_neighbors.T + + return probe_num_neighbors + + +def get_gaussian_cutoff_weights( + effective_num_neighbors: torch.Tensor, + num_neighbors_adaptive: float, + width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the weights for each probe cutoff based on + the effective number of neighbors using Gaussian weights + centered at the expected number of neighbors. + + :param effective_num_neighbors: Effective number of neighbors for each center atom + and probe cutoff. + :param num_neighbors_adaptive: Target maximum number of neighbors per atom. + :param width: Width of the Gaussian cutoff selection function. + :return: Weights for each probe cutoff. + """ + diff = effective_num_neighbors - num_neighbors_adaptive + + # adds a "baseline" corresponding to uniformly-distributed atoms + # this has multiple "good" effects: it pushes the cutoff "out" when + # there are few neighbors, and "in" when there are many, and it + # stabilizes the weights with respect to variations in the neighbor + # distribution when there are empty ranges leading to "flat" + # neighbor count distribution + x = torch.linspace( + 0, + 1, + effective_num_neighbors.shape[1], + device=effective_num_neighbors.device, + dtype=effective_num_neighbors.dtype, + ) + baseline = num_neighbors_adaptive * x**3 + + diff = diff + baseline.unsqueeze(0) + if width is None: + # adaptive width from neighbor-count slope along probe axis (last dim) + eps = 1e-12 + if diff.shape[-1] == 1: + # Can't compute gradient from single point; use scaled diff as proxy + width_t = diff.abs() * 0.5 + eps + else: + # Compute numerical gradient: centered differences for interior, + # one-sided differences at boundaries + (width_t,) = torch.gradient(diff, dim=-1) + width_t = width_t.abs().clamp_min(eps) + else: + width_t = torch.ones_like(diff) * width + + logw = -0.5 * (diff / width_t) ** 2 + weights = torch.exp(logw - logw.max()) + + # row-wise normalization of the weights + weights_sum = weights.sum(dim=1, keepdim=True) + weights = weights / weights_sum + + return weights diff --git a/src/metatrain/pet_7/modules/finetuning.py b/src/metatrain/pet_7/modules/finetuning.py new file mode 100644 index 0000000000..27f2086f19 --- /dev/null +++ b/src/metatrain/pet_7/modules/finetuning.py @@ -0,0 +1,264 @@ +# mypy: disable-error-code=misc +# We ignore misc errors in this file because TypedDict +# with default values is not allowed by mypy. 
+from typing import Any, Optional, Tuple + +import torch +import torch.nn as nn +from typing_extensions import Literal, NotRequired, TypedDict + + +class LoRaFinetuneConfig(TypedDict): + """Configuration for LoRA finetuning strategy.""" + + rank: int + """Rank of the LoRA matrices.""" + alpha: float + """Scaling factor for the LoRA matrices.""" + target_modules: NotRequired[list[str]] + + +class HeadsFinetuneConfig(TypedDict): + """Configuration for heads finetuning strategy.""" + + head_modules: list[str] + """List of module name prefixes for the prediction heads to finetune.""" + last_layer_modules: list[str] + """List of module name prefixes for the last layers to finetune.""" + + +class NoFinetuneHypers(TypedDict): + """Hypers that indicate that no finetuning is to be applied.""" + + read_from: None = None + """No finetuning is indicated by setting this argument to None. + + The rest of finetuning hyperparameters are then ignored. + """ + method: NotRequired[Any] + config: NotRequired[Any] + inherit_heads: NotRequired[Any] + + +class FullFinetuneHypers(TypedDict): + """Hyperparameters to use full finetuning of PET models. + + This means all model parameters are trainable. + """ + + method: Literal["full"] = "full" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: NotRequired[Any] + """No configuration needed for full finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class LoRaFinetuneHypers(TypedDict): + """Hyperparameters for LoRA finetuning of PET models. + + Injects LoRA layers and finetunes only them. + """ + + method: Literal["lora"] = "lora" + """Finetuning method to use""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: LoRaFinetuneConfig + """Configuration for LoRA finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class HeadsFinetuneHypers(TypedDict): + """Hyperparameters for heads finetuning of PET models. + + Freezes all model parameters except for the prediction heads + and last layers. + """ + + method: Literal["heads"] = "heads" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: HeadsFinetuneConfig + """Configuration for heads finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +FinetuneHypers = FullFinetuneHypers | LoRaFinetuneHypers | HeadsFinetuneHypers + + +def apply_finetuning_strategy(model: nn.Module, strategy: FinetuneHypers) -> nn.Module: + """ + Apply the specified finetuning strategy to the model. + This function modifies the model in place based on the provided strategy. + + :param model: The model to be finetuned. + :param strategy: A dictionary specifying the finetuning strategy. + The strategy method can be one of the following: + - lora: Inject LoRA layers into the model, or reapply training if already + present. 
+ - heads: Freeze all parameters except for the heads and last layers. + - full: All parameters are trainable. + Additionally, the strategy can include an "inherit_heads" key, + which is a dictionary mapping the new trainable targets to the existing + targets in the model. This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization. + :return: The modified model with the finetuning strategy applied. + """ + + for param in model.parameters(): + param.requires_grad = True + + if strategy["method"] == "full": + # Full finetuning, all parameters are trainable + pass + + elif strategy["method"] == "lora": + lora_config = strategy["config"] + lora_already_applied = any(isinstance(m, LoRALinear) for m in model.modules()) + if not lora_already_applied: + model_device = next(model.parameters()).device + model_dtype = next(model.parameters()).dtype + model = inject_lora_layers( + model, + target_modules=tuple( + lora_config.get("target_modules", ["input_linear", "output_linear"]) + ), + rank=lora_config.get("rank", 4), + alpha=lora_config.get("alpha", 8), + device=model_device, + dtype=model_dtype, + ) + + # Freeze all except LoRA + for name, param in model.named_parameters(): + if "lora_" not in name: + param.requires_grad = False + + elif strategy["method"] == "heads": + heads_config = strategy.get( + "config", + { + "head_modules": ["node_heads", "edge_heads"], + "last_layer_modules": ["node_last_layers", "edge_last_layers"], + }, + ) + + head_keywords = heads_config.get("head_modules", []) + last_layer_keywords = heads_config.get("last_layer_modules", []) + + for name, param in model.named_parameters(): + if any(name.startswith(kw) for kw in head_keywords + last_layer_keywords): + param.requires_grad = True + else: + param.requires_grad = False + + else: + raise ValueError( + f"Unknown finetuning strategy: {strategy['method']}. Available methods " + "are: 'full', 'lora', 'heads'." + ) + + model.finetune_config = strategy + + inherit_heads_config = strategy["inherit_heads"] + if inherit_heads_config: + for dest_target_name, source_target_name in inherit_heads_config.items(): + model_parameters = dict(model.named_parameters()) + if not any(f".{source_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the source target name '{source_target_name}' was not found in " + "the model. Please specify the correct source target name." + ) + if not any(f".{dest_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the destination target name '{dest_target_name}' was not found " + "in the model. Please specify the correct destination target name." + ) + for name, param in model_parameters.items(): + if f".{source_target_name}." in name: + corresponding_dest_name = name.replace( + source_target_name, dest_target_name + ) + if corresponding_dest_name in model_parameters: + model_parameters[corresponding_dest_name].data.copy_(param.data) + else: + raise ValueError( + f"Destination head '{dest_target_name}' not found in model." + ) + return model + + +def inject_lora_layers( + model: nn.Module, + target_modules: Tuple[str, ...] = ("input_linear", "output_linear"), + rank: int = 4, + alpha: float = 1.0, + device: Optional[str] = None, + dtype: Optional[torch.dtype] = None, +) -> nn.Module: + """ + Inject LoRA layers into the model. 
+ This function replaces the specified linear layers in the model with + LoRALinear layers. + + :param model: The model to modify. + :param target_modules: A tuple of strings specifying the names of the attributes of + the modules to be replaced with LoRA layers. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + :param device: The device to which the LoRA layers should be moved. If None, the + LoRA layers will be on the same device as the original model. + :param dtype: The data type to which the LoRA layers should be converted. If + None, the LoRA layers will have the same dtype as the original model. + :return: The modified model with LoRA layers injected. + """ + for _, module in model.named_modules(): + for attr in target_modules: + if hasattr(module, attr): + linear = getattr(module, attr) + if isinstance(linear, nn.Linear): + lora_linear = LoRALinear(linear, rank=rank, alpha=alpha) + lora_linear = lora_linear.to(dtype=dtype, device=device) + setattr(module, attr, lora_linear) + return model + + +class LoRALinear(nn.Module): + """ + LoRA Linear layer. + This is a wrapper around nn.Linear that adds LoRA functionality. + LoRA is a technique for low-rank adaptation of large language models. + It allows for efficient fine-tuning of large models by injecting low-rank + matrices into the model's weights. + + :param linear_layer: The original linear layer to be wrapped. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + """ + + def __init__(self, linear_layer: nn.Module, rank: int = 4, alpha: float = 1.0): + super().__init__() + self.linear = linear_layer + self.lora_A = nn.Linear(linear_layer.in_features, rank, bias=False) + self.lora_B = nn.Linear(rank, linear_layer.out_features, bias=False) + self.scaling = alpha / rank + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + self.scaling * self.lora_B(self.lora_A(x)) diff --git a/src/metatrain/pet_7/modules/nef.py b/src/metatrain/pet_7/modules/nef.py new file mode 100644 index 0000000000..20ed6dc8fc --- /dev/null +++ b/src/metatrain/pet_7/modules/nef.py @@ -0,0 +1,248 @@ +""" +Module with functions to manipulate NEF (Node Edge Feature) arrays. + +The NEF representation is what the internals of PET use. +In the NEF representation, the first dimension is the center node +(i.e. the "i" node in an "i -> j" edge), and the second dimension +is the edges for that node. Not all center nodes have the same number +of edges, so padding is used to ensure that all nodes have the same +number of edges. + +Most of the functions have the purpose of converting between +edge arrays with shape (n_edges, ...) and NEF arrays with shape +(n_nodes, n_edges_per_node, ...). +""" + +from typing import List, Optional, Tuple + +import torch + + +def get_nef_indices( + centers: torch.Tensor, n_nodes: int, n_edges_per_node: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes tensors of indices useful to convert between edge + and NEF layouts; the usage and function of `nef_indices` and + `nef_to_edges_neighbor` is clear in the ``edge_array_to_nef`` + and ``nef_array_to_edges`` functions below. + + :param centers: A 1D tensor of shape (n_edges,) containing the + indices of the center nodes for each edge, with the center nodes + being the "i" node in an "i -> j" edge. + :param n_nodes: The number of nodes in the graph. + :param n_edges_per_node: The maximum number of edges per node. 
+ + :return: A tuple with three tensors (nef_indices, nef_to_edges_neighbor, nef_mask). + In particular: + nef_array = edge_array[nef_indices] + edge_array = nef_array[centers, nef_to_edges_neighbor] + The third output, nef_mask, is a mask that can be used to + filter out the padding values in the NEF array, as different + nodes will have, in general, different number of edges. + """ + + bincount = torch.bincount(centers, minlength=n_nodes) + + arange = torch.arange(n_edges_per_node, device=centers.device) + arange_expanded = arange.view(1, -1).expand(n_nodes, -1) + nef_mask = arange_expanded < bincount.view(-1, 1) + + argsort = torch.argsort(centers, stable=True) + + nef_indices = torch.zeros( + (n_nodes, n_edges_per_node), dtype=torch.long, device=centers.device + ) + nef_indices[nef_mask] = argsort + + nef_to_edges_neighbor = torch.empty_like(centers, dtype=torch.long) + nef_to_edges_neighbor[argsort] = arange_expanded[nef_mask] + + return nef_indices, nef_to_edges_neighbor, nef_mask + + +def get_corresponding_edges(array: torch.Tensor) -> torch.Tensor: + """ + Computes the corresponding edge (i.e., the edge that goes in the + opposite direction) for each edge in the array; this is useful + in the message-passing operation. + + :param array: A 2D tensor of shape (n_edges, 5). For each i -> j + edge, the first column contains the index of the center node i, + the second column contains the index of the neighbor node j, + and the last three columns contain the cell shifts along x, y, and z + directions, respectively. + + :return: A 1D tensor of shape (n_edges,) containing, for each edge, + the index of the corresponding edge (i.e., the edge that goes + in the opposite direction). If the input array is empty, an + empty tensor is returned. + """ + + if array.numel() == 0: + return torch.empty((0,), dtype=array.dtype, device=array.device) + + array = array.to(torch.int64) # avoid overflow + + centers = array[:, 0] + neighbors = array[:, 1] + cell_shifts_x = array[:, 2] + cell_shifts_y = array[:, 3] + cell_shifts_z = array[:, 4] + + # will be useful later + negative_cell_shifts_x = -cell_shifts_x + negative_cell_shifts_y = -cell_shifts_y + negative_cell_shifts_z = -cell_shifts_z + + # create a unique identifier for each edge + # first, we shift the cell_shifts so that the minimum value is 0 + min_cell_shift_x = cell_shifts_x.min() + cell_shifts_x = cell_shifts_x - min_cell_shift_x + negative_cell_shifts_x = negative_cell_shifts_x - min_cell_shift_x + + min_cell_shift_y = cell_shifts_y.min() + cell_shifts_y = cell_shifts_y - min_cell_shift_y + negative_cell_shifts_y = negative_cell_shifts_y - min_cell_shift_y + + min_cell_shift_z = cell_shifts_z.min() + cell_shifts_z = cell_shifts_z - min_cell_shift_z + negative_cell_shifts_z = negative_cell_shifts_z - min_cell_shift_z + + max_centers_neigbors = centers.max() + 1 # same as neighbors.max() + 1 + max_shift_x = cell_shifts_x.max() + 1 + max_shift_y = cell_shifts_y.max() + 1 + max_shift_z = cell_shifts_z.max() + 1 + + size_1 = max_shift_z + size_2 = max_shift_y * size_1 + size_3 = max_shift_x * size_2 + size_4 = max_centers_neigbors * size_3 + + unique_id = ( + centers * size_4 + + neighbors * size_3 + + cell_shifts_x * size_2 + + cell_shifts_y * size_1 + + cell_shifts_z + ) + + # the inverse is the same, but centers and neighbors are swapped + # and we use the negative values of the cell_shifts + unique_id_inverse = ( + neighbors * size_4 + + centers * size_3 + + negative_cell_shifts_x * size_2 + + negative_cell_shifts_y * size_1 + + 
negative_cell_shifts_z + ) + + unique_id_argsort = unique_id.argsort() + unique_id_inverse_argsort = unique_id_inverse.argsort() + + corresponding_edges = torch.empty_like(centers) + corresponding_edges[unique_id_argsort] = unique_id_inverse_argsort + + return corresponding_edges.to(array.dtype) + + +def edge_array_to_nef( + edge_array: torch.Tensor, + nef_indices: torch.Tensor, + mask: Optional[torch.Tensor] = None, + fill_value: float = 0.0, +) -> torch.Tensor: + """Converts an edge array to a NEF array. + + :param edge_array: A tensor where the first dimension is the index of + the edge, i.e. with shape (n_edges, ...). + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param mask: An optional boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. If provided, + the output NEF array will have the values in the positions + where the mask is False set to ``fill_value``. + :param fill_value: The value to use to fill the positions in the + NEF array where the mask is False. Only used if ``mask`` is + provided. + + :return: A tensor with the same information as ``edge_array``, + but in NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + If ``mask`` is provided, the values in the positions where + the mask is False are set to ``fill_value``. + """ + if mask is None: + return edge_array[nef_indices] + else: + return torch.where( + mask.reshape(mask.shape + (1,) * (len(edge_array.shape) - 1)), + edge_array[nef_indices], + fill_value, + ) + + +def nef_array_to_edges( + nef_array: torch.Tensor, centers: torch.Tensor, nef_to_edges_neighbor: torch.Tensor +) -> torch.Tensor: + """Converts a NEF array to an edge array. + + :param nef_array: A tensor where the first two dimensions are the + indices of the NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + :param centers: The indices of the center nodes for each edge. + :param nef_to_edges_neighbor: The indices of the edges for each + neighbor in the NEF layout, as returned by the ``get_nef_indices`` function. + + :return: A tensor with the same information as ``nef_array``, + but in edge layout, i.e. with shape (n_edges, ...). + """ + return nef_array[centers, nef_to_edges_neighbor] + + +def compute_reversed_neighbor_list( + nef_indices: torch.Tensor, + corresponding_edges: torch.Tensor, + nef_mask: torch.Tensor, +) -> torch.Tensor: + """ + Creates a reversed neighborlist, where for each + center atom `i` and its neighbor `j` in the original + neighborlist, the position of atom `i` in the list + of neighbors of atom `j` is returned. + + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param corresponding_edges: The indices of the corresponding edges, + as returned by the ``get_corresponding_edges`` function. + :param nef_mask: A boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. + :return: A tensor of the same shape as ``nef_indices``, + where each entry contains the position of the center + atom in the neighborlist of the corresponding neighbor atom. 
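+
+    As a small illustrative example (the indices are hypothetical): if atom ``1``
+    is stored at slot ``0`` of the neighbor list of atom ``2``, and atom ``2``
+    appears at position ``1`` in the neighbor list of atom ``1``, then
+    ``reversed_neighbor_list[2, 0] == 1``. Entries corresponding to padded
+    neighbors are filled with ``0``.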
+ """ + num_atoms, max_num_neighbors = nef_indices.shape + + flat_edge_indices = nef_indices.reshape(-1) + flat_positions = torch.arange(max_num_neighbors, device=nef_indices.device).repeat( + num_atoms + ) + flat_mask = nef_mask.reshape(-1) + + if flat_edge_indices.numel() == 0: + max_edge_index = 0 + else: + max_edge_index = int(flat_edge_indices.max().item()) + 1 + size: List[int] = [max_edge_index] + + edge_index_to_position = torch.full( + size, + 0, + dtype=torch.long, + device=nef_indices.device, + ) + edge_index_to_position[flat_edge_indices[flat_mask]] = flat_positions[flat_mask] + + reverse_edge_idx = corresponding_edges[nef_indices] + reversed_neighbor_list = edge_index_to_position[reverse_edge_idx] + reversed_neighbor_list = reversed_neighbor_list.masked_fill(~nef_mask, 0) + + return reversed_neighbor_list diff --git a/src/metatrain/pet_7/modules/structures.py b/src/metatrain/pet_7/modules/structures.py new file mode 100644 index 0000000000..f7e67f9eaf --- /dev/null +++ b/src/metatrain/pet_7/modules/structures.py @@ -0,0 +1,377 @@ +from typing import List, Optional, Tuple + +import torch +from metatensor.torch import Labels +from metatomic.torch import NeighborListOptions, System + +from .adaptive_cutoff import get_adaptive_cutoffs +from .nef import ( + compute_reversed_neighbor_list, + edge_array_to_nef, + get_corresponding_edges, + get_nef_indices, +) +from .utilities import cutoff_func_bump, cutoff_func_cosine + + +def concatenate_structures( + systems: List[System], + neighbor_list_options: NeighborListOptions, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, +]: + """ + Concatenate a list of systems into a single batch. + + :param systems: List of systems to concatenate. + :param neighbor_list_options: Options for the neighbor list. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the concatenated positions, centers, neighbors, + species, cells, cell shifts, system indices, and sample labels. 
+ """ + + positions: List[torch.Tensor] = [] + centers: List[torch.Tensor] = [] + neighbors: List[torch.Tensor] = [] + species: List[torch.Tensor] = [] + cell_shifts: List[torch.Tensor] = [] + cells: List[torch.Tensor] = [] + system_indices: List[torch.Tensor] = [] + atom_indices: List[torch.Tensor] = [] + node_counter = 0 + + for i, system in enumerate(systems): + assert len(system.known_neighbor_lists()) >= 1, "no neighbor list found" + neighbor_list = system.get_neighbor_list(neighbor_list_options) + nl_values = neighbor_list.samples.values + + centers_values = nl_values[:, 0] + neighbors_values = nl_values[:, 1] + cell_shifts_values = nl_values[:, 2:] + + if selected_atoms is not None: + system_selected_atoms = selected_atoms.values[:, 1][ + selected_atoms.values[:, 0] == i + ] + unique_centers = torch.unique(centers_values) + system_selected_atoms = torch.unique( + torch.cat([system_selected_atoms, unique_centers]) + ) + # calculate the mapping from the ghost atoms to the real atoms + if torch.numel(unique_centers) == 0: + max_center_index = -1 + else: + max_center_index = int(unique_centers.max()) + ghost_to_real_index = torch.full( + [ + max_center_index + 1, + ], + -1, + device=centers_values.device, + dtype=centers_values.dtype, + ) + for j, unique_center_index in enumerate(unique_centers): + ghost_to_real_index[unique_center_index] = j + + centers_values = ghost_to_real_index[centers_values] + neighbors_values = ghost_to_real_index[neighbors_values] + else: + system_selected_atoms = torch.arange( + len(system), device=system.positions.device + ) + + positions.append(system.positions[system_selected_atoms]) + species.append(system.types[system_selected_atoms]) + + centers.append(centers_values + node_counter) + neighbors.append(neighbors_values + node_counter) + cell_shifts.append(cell_shifts_values) + + cells.append(system.cell) + + node_counter += len(system_selected_atoms) + system_indices.append( + torch.full((len(system_selected_atoms),), i, device=system.positions.device) + ) + atom_indices.append( + torch.arange(len(system_selected_atoms), device=system.positions.device) + ) + + positions = torch.cat(positions) + centers = torch.cat(centers) + neighbors = torch.cat(neighbors) + species = torch.cat(species) + cells = torch.stack(cells) + cell_shifts = torch.cat(cell_shifts) + system_indices = torch.cat(system_indices) + atom_indices = torch.cat(atom_indices) + + sample_values = torch.stack( + [system_indices, atom_indices], + dim=1, + ) + sample_labels = Labels( + names=["system", "atom"], + values=sample_values, + assume_unique=True, + ) + + return ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) + + +def systems_to_batch( + systems: List[System], + options: NeighborListOptions, + all_species_list: List[int], + species_to_species_index: torch.Tensor, + cutoff_function: str, + cutoff_width: float, + num_neighbors_adaptive: Optional[float] = None, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, + torch.Tensor, + torch.Tensor, +]: + """ + Converts a list of systems to a batch required for the PET model. + + :param systems: List of systems to convert to a batch. + :param options: Options for the neighbor list. + :param all_species_list: List of all atomic species in the dataset. + :param species_to_species_index: Mapping from atomic species to species indices. 
+ :param cutoff_function: Type of the smoothing function at the cutoff. + :param cutoff_width: Width of the cutoff function for a cutoff mask. + :param num_neighbors_adaptive: Optional maximum number of neighbors per atom. + If provided, the adaptive cutoff scheme will be used for each atom to + approximately select this number of neighbors. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the batch tensors. + The batch consists of the following tensors: + - `element_indices_nodes`: The atomic species of the central atoms + - `element_indices_neighbors`: The atomic species of the neighboring atoms + - `edge_vectors`: The cartesian edge vectors between the central atoms and their + neighbors + - `edge_distances`: The distances between the central atoms and their neighbors + - `padding_mask`: A padding mask indicating which neighbors are real, and which + are padded + - `reverse_neighbor_index`: The reversed neighbor list for each central atom + - `cutoff_factors`: The cutoff function values for each edge + - `system_indices`: The system index for each atom in the batch + - `sample_labels`: Labels indicating the system and atom indices for each atom + + """ + ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) = concatenate_structures(systems, options, selected_atoms) + + # somehow the backward of this operation is very slow at evaluation, + # where there is only one cell, therefore we simplify the calculation + # for that case + if len(cells) == 1: + cell_contributions = cell_shifts.to(cells.dtype) @ cells[0] + else: + cell_contributions = torch.einsum( + "ab, abc -> ac", + cell_shifts.to(cells.dtype), + cells[system_indices[centers]], + ) + edge_vectors = positions[neighbors] - positions[centers] + cell_contributions + edge_distances = torch.norm(edge_vectors, dim=-1) + 1e-15 + + if selected_atoms is not None: + if torch.numel(centers) == 0: + num_nodes = 0 + else: + num_nodes = int(centers.max()) + 1 + else: + num_nodes = len(positions) + + if num_neighbors_adaptive is not None: + with torch.profiler.record_function("PET::get_adaptive_cutoffs"): + # Adaptive cutoff scheme to approximately select `num_neighbors_adaptive` + # neighbors for each atom + atomic_cutoffs = get_adaptive_cutoffs( + centers, + edge_distances, + num_neighbors_adaptive, + num_nodes, + options.cutoff, + cutoff_width=cutoff_width, + ) + # Symmetrize the cutoffs between pairs of atoms (PET needs this symmetry + # due to its corresponding edge indexing ij -> ji) + pair_cutoffs = (atomic_cutoffs[centers] + atomic_cutoffs[neighbors]) / 2.0 + with torch.profiler.record_function("PET::adaptive_cutoff_masking"): + # Apply cutoff mask + cutoff_mask = edge_distances <= pair_cutoffs + + pair_cutoffs = pair_cutoffs[cutoff_mask] + centers = centers[cutoff_mask] + neighbors = neighbors[cutoff_mask] + edge_vectors = edge_vectors[cutoff_mask] + cell_shifts = cell_shifts[cutoff_mask] + edge_distances = edge_distances[cutoff_mask] + else: + pair_cutoffs = options.cutoff * torch.ones( + len(centers), device=positions.device, dtype=positions.dtype + ) + + num_neighbors = torch.bincount(centers, minlength=num_nodes) + max_edges_per_node = int(torch.max(num_neighbors)) + + # uncomment these to print out stats on the adaptive cutoff behavior + # print("adaptive_cutoffs", *pair_cutoffs.tolist()) + # print("num_neighbors", *num_neighbors.tolist()) + + if cutoff_function.lower() == "bump": + # use bump switching 
function for adaptive cutoff + cutoff_factors = cutoff_func_bump(edge_distances, pair_cutoffs, cutoff_width) + elif cutoff_function.lower() == "cosine": + # backward-compatible cosine swithcing for fixed cutoff + cutoff_factors = cutoff_func_cosine(edge_distances, pair_cutoffs, cutoff_width) + else: + raise ValueError( + f"Unknown cutoff function type: {cutoff_function}. " + f"Supported types are 'Cosine' and 'Bump'." + ) + + # Convert to NEF (Node-Edge-Feature) format: + nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices( + centers, num_nodes, max_edges_per_node + ) + + # Element indices + element_indices_nodes = species_to_species_index[species] + element_indices_neighbors = element_indices_nodes[neighbors] + + # Send everything to NEF: + edge_vectors = edge_array_to_nef(edge_vectors, nef_indices) + edge_distances = torch.sqrt(torch.sum(edge_vectors**2, dim=2) + 1e-15) + element_indices_neighbors = edge_array_to_nef( + element_indices_neighbors, nef_indices + ) + cutoff_factors = edge_array_to_nef(cutoff_factors, nef_indices, nef_mask, 0.0) + + corresponding_edges = get_corresponding_edges( + torch.concatenate( + [centers.unsqueeze(-1), neighbors.unsqueeze(-1), cell_shifts], + dim=-1, + ) + ) + + # These are the two arrays we need for message passing with edge reversals, + # if indexing happens in a two-dimensional way: + # edges_ji = edges_ij[reversed_neighbor_list, neighbors_index] + reversed_neighbor_list = compute_reversed_neighbor_list( + nef_indices, corresponding_edges, nef_mask + ) + neighbors_index = edge_array_to_nef(neighbors, nef_indices).to(torch.int64) + + # Here, we compute the array that allows indexing into a flattened + # version of the edge array (where the first two dimensions are merged): + reverse_neighbor_index = ( + neighbors_index * neighbors_index.shape[1] + reversed_neighbor_list + ) + # At this point, we have `reverse_neighbor_index[~nef_mask] = 0`, which however + # creates too many of the same index which slows down backward enormously. + # (See see https://github.com/pytorch/pytorch/issues/41162) + # We therefore replace the padded indices with a sequence of unique indices. + reverse_neighbor_index[~nef_mask] = torch.arange( + int(torch.sum(~nef_mask)), device=reverse_neighbor_index.device + ) + + return ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + nef_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + sample_labels, + centers, + nef_to_edges_neighbor, + ) + + +def get_pair_sample_labels( + systems: List[System], + sample_labels: Labels, + nl_options: NeighborListOptions, + device: torch.device, +) -> Labels: + """ + Builds the pair samples labels for the input ``systems``, based on the pre-computed + neighbor list. These are 'off-site', i.e. not including self-interactions. + + :param systems: List of systems to build the pair sample labels for. + :param sample_labels: The sample labels for per-atom quantities. + :param nl_options: The neighbor list options to use for building the offsite labels. + :param device: The device to put the labels on. + :return: A dictionary with the pair sample labels for the onsite and offsite blocks. 
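The NEF (Node-Edge-Feature) conversion performed by `get_nef_indices` and `edge_array_to_nef` in `systems_to_batch` above rearranges the flat neighbor list into a dense per-atom layout. Below is a minimal pure-PyTorch sketch of the same idea; it is not the metatrain helpers themselves, and `to_padded_neighbors` plus the toy tensors are made up for illustration.

```python
import torch


def to_padded_neighbors(centers, edge_values, num_nodes):
    """Scatter flat per-edge values into a [num_nodes, max_edges] padded layout."""
    num_neighbors = torch.bincount(centers, minlength=num_nodes)
    max_edges = int(num_neighbors.max())

    # slot of each edge within its node's row (0, 1, 2, ... per center)
    order = torch.argsort(centers, stable=True)
    offsets = torch.cumsum(num_neighbors, dim=0) - num_neighbors
    slot = torch.arange(len(centers)) - offsets[centers[order]]

    padded = torch.zeros(num_nodes, max_edges, dtype=edge_values.dtype)
    mask = torch.zeros(num_nodes, max_edges, dtype=torch.bool)
    padded[centers[order], slot] = edge_values[order]
    mask[centers[order], slot] = True
    return padded, mask


# three atoms with 2, 3 and 1 neighbors respectively
centers = torch.tensor([0, 0, 1, 1, 1, 2])
distances = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
padded, mask = to_padded_neighbors(centers, distances, num_nodes=3)
print(padded)  # [[1., 2., 0.], [3., 4., 5.], [6., 0., 0.]]
print(mask)    # [[True, True, False], [True, True, True], [True, False, False]]
```

The boolean mask plays the role of `nef_mask` above, which is what later allows padded entries to be zeroed out or re-indexed safely.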
+ """ + sample_names = [ + "system", + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ] + + pair_sample_values = [] + for system_idx, system in enumerate(systems): + neighbor_list = system.get_neighbor_list(nl_options) + nl_values = neighbor_list.samples.values + + pair_sample_values.append( + torch.hstack( + [ + torch.full( + (nl_values.shape[0], 1), + system_idx, + dtype=torch.int32, + device=device, + ), + nl_values, + ], + ) + ) + pair_sample_values = torch.vstack(pair_sample_values) + pair_sample_labels = Labels(sample_names, pair_sample_values).to(device=device) + + return pair_sample_labels diff --git a/src/metatrain/pet_7/modules/transformer.py b/src/metatrain/pet_7/modules/transformer.py new file mode 100644 index 0000000000..c68ad54358 --- /dev/null +++ b/src/metatrain/pet_7/modules/transformer.py @@ -0,0 +1,571 @@ +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from .utilities import DummyModule +from sphericart.torch import SphericalHarmonics + + +AVAILABLE_NORMALIZATIONS = ["LayerNorm", "RMSNorm"] +AVAILABLE_TRANSFORMER_TYPES = ["PostLN", "PreLN"] +AVAILABLE_ACTIVATIONS = ["SiLU", "SwiGLU"] + + +class FeedForward(nn.Module): + def __init__(self, d_model: int, dim_feedforward: int, activation: str) -> None: + super().__init__() + + # Check if activation is "swiglu" string + if activation.lower() == "swiglu": + # SwiGLU mode: single projection produces both "value" and "gate" + self.w_in = nn.Linear(d_model, 2 * dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = torch.nn.Identity() + self.is_swiglu = True + else: + # Standard mode: regular activation function + self.w_in = nn.Linear(d_model, dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = getattr(F, activation.lower()) + self.is_swiglu = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_swiglu: + # SwiGLU activation: split into value and gate + v, g = self.w_in(x).chunk(2, dim=-1) + x = v * torch.sigmoid(g) + x = self.w_out(x) + else: + # Standard activation + x = self.w_in(x) + x = self.activation(x) + x = self.w_out(x) + return x + + +class AttentionBlock(nn.Module): + """ + Multi-head attention block. + + :param total_dim: The total dimension of the input and output tensors. + :param num_heads: The number of attention heads. + :param temperature: An additional scaling factor for attention scores. + This is combined with the standard scaling by the square root of + the head dimension. + :param epsilon: A small value to avoid division by zero. + """ + + def __init__( + self, + total_dim: int, + num_heads: int, + temperature: float, + epsilon: float = 1e-15, + ) -> None: + super(AttentionBlock, self).__init__() + + self.input_linear = nn.Linear(total_dim, 3 * total_dim) + self.output_linear = nn.Linear(total_dim, total_dim) + + self.num_heads = num_heads + self.epsilon = epsilon + self.temperature = temperature + if total_dim % num_heads != 0: + raise ValueError("total dimension is not divisible by the number of heads") + self.head_dim = total_dim // num_heads + + def forward( + self, x: torch.Tensor, cutoff_factors: torch.Tensor, use_manual_attention: bool + ) -> torch.Tensor: + """ + Forward pass for the attention block. 
+ + :param x: The input tensor, of shape (batch_size, seq_length, total_dim) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: The output tensor, of shape (batch_size, seq_length, total_dim) + """ + initial_shape = x.shape + x = self.input_linear(x) + x = x.reshape( + initial_shape[0], initial_shape[1], 3, self.num_heads, self.head_dim + ) + x = x.permute(2, 0, 3, 1, 4) + + queries, keys, values = x[0], x[1], x[2] + attn_weights = torch.clamp(cutoff_factors[:, None, :, :], self.epsilon) + attn_weights = torch.log(attn_weights) + if use_manual_attention: + x = manual_attention(queries, keys, values, attn_weights, self.temperature) + else: + x = torch.nn.functional.scaled_dot_product_attention( + queries, + keys, + values, + attn_mask=attn_weights, + scale=1.0 / (self.head_dim**0.5 * self.temperature), + ) + x = x.transpose(1, 2).reshape(initial_shape) + x = self.output_linear(x) + return x + + +class TransformerLayer(torch.nn.Module): + """ + Single layer of a Transformer. + + :param d_model: The dimension of the model. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param temperature: An additional scaling factor for attention scores. 
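A small numerical check of the masking trick used in `AttentionBlock.forward` above: adding `log(cutoff_factors)` to the attention scores before the softmax is the same as multiplying the attention weights by the smooth cutoff factors and renormalizing. Toy values, single head, no scaling.

```python
import torch

torch.manual_seed(0)
scores = torch.randn(5)                        # raw attention scores for 5 neighbors
c = torch.tensor([1.0, 1.0, 0.6, 0.1, 0.0])    # smooth cutoff factor per neighbor
eps = 1e-15

# additive log-mask, as in AttentionBlock.forward
masked = torch.softmax(scores + torch.log(c.clamp(min=eps)), dim=-1)

# equivalent multiplicative view
weighted = c * torch.exp(scores)
weighted = weighted / weighted.sum()

print(torch.allclose(masked, weighted, atol=1e-6))  # True
print(masked[-1])  # ~0: the neighbor at the cutoff contributes (almost) nothing
```

Clamping at `epsilon` before taking the log keeps fully cut-off neighbors finite instead of producing `-inf` entries in the mask.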
+ """ + + def __init__( + self, + d_model: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + temperature: float = 1.0, + ) -> None: + super(TransformerLayer, self).__init__() + self.attention = AttentionBlock(d_model, n_heads, temperature) + self.transformer_type = transformer_type + self.d_model = d_model + norm_class = getattr(nn, norm) + self.norm_attention = norm_class(d_model) + self.norm_mlp = norm_class(d_model) + self.mlp = FeedForward(d_model, dim_feedforward, activation) + self.expanded_node_features = False + if dim_node_features != d_model: + self.expanded_node_features = True + self.center_contraction = nn.Linear(dim_node_features, d_model) + self.center_expansion = nn.Linear(d_model, dim_node_features) + self.norm_center_features = norm_class(dim_node_features) + self.center_mlp = FeedForward( + dim_node_features, 2 * dim_node_features, activation + ) + else: + self.center_contraction = torch.nn.Identity() + self.center_expansion = torch.nn.Identity() + self.norm_center_features = torch.nn.Identity() + self.center_mlp = torch.nn.Identity() + + def _forward_pre_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + new_tokens = self.attention( + self.norm_attention(tokens), cutoff_factors, use_manual_attention + ) + output_node_embeddings, output_edge_embeddings = torch.split( + new_tokens, [1, new_tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + + output_edge_embeddings = edge_embeddings + output_edge_embeddings + output_edge_embeddings = output_edge_embeddings + self.mlp( + self.norm_mlp(output_edge_embeddings) + ) + + return output_node_embeddings, output_edge_embeddings + + def _forward_post_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + tokens = self.norm_attention( + tokens + self.attention(tokens, cutoff_factors, use_manual_attention) + ) + tokens = self.norm_mlp(tokens + self.mlp(tokens)) + output_node_embeddings, output_edge_embeddings = torch.split( + tokens, [1, tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + return output_node_embeddings, output_edge_embeddings + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward 
pass for a single Transformer layer. + + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + if self.transformer_type == "PostLN": + node_embeddings, edge_embeddings = self._forward_post_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + if self.transformer_type == "PreLN": + node_embeddings, edge_embeddings = self._forward_pre_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class Transformer(torch.nn.Module): + """ + Transformer implementation. + + :param d_model: The dimension of the model. + :param num_layers: The number of transformer layers. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param attention_temperature: The temperature scaling factor for attention + scores. This is combined with the standard scaling by the square root of + the head dimension. + """ + + def __init__( + self, + d_model: int, + num_layers: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + attention_temperature: float = 1.0, + ) -> None: + super(Transformer, self).__init__() + if norm not in AVAILABLE_NORMALIZATIONS: + raise ValueError( + f"Unknown normalization flag: {norm}. " + f"Please choose from: {AVAILABLE_NORMALIZATIONS}" + ) + + if transformer_type not in AVAILABLE_TRANSFORMER_TYPES: + raise ValueError( + f"Unknown transformer flag: {transformer_type}. " + f"Please choose from: {AVAILABLE_TRANSFORMER_TYPES}" + ) + self.transformer_type = transformer_type + + if activation not in AVAILABLE_ACTIVATIONS: + raise ValueError( + f"Unknown activation flag: {activation}. " + f"Please choose from: {AVAILABLE_ACTIVATIONS}" + ) + + self.layers = nn.ModuleList( + [ + TransformerLayer( + d_model=d_model, + n_heads=n_heads, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + temperature=attention_temperature, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the Transformer. 
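The only difference between `_forward_post_ln_impl` and `_forward_pre_ln_impl` above is where the normalization sits relative to the residual connection. A generic sketch with a stand-in sublayer; the linear layer here is just a placeholder for attention or the feedforward block.

```python
import torch
import torch.nn as nn

d_model = 16
x = torch.randn(4, 10, d_model)
norm = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)  # stand-in for attention or the MLP

post_ln = norm(x + sublayer(x))   # PostLN: residual first, then normalize
pre_ln = x + sublayer(norm(x))    # PreLN: normalize first, keep an un-normalized residual path

print(post_ln.shape, pre_ln.shape)  # both torch.Size([4, 10, 16])
```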
+ + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + for layer in self.layers: + node_embeddings, edge_embeddings = layer( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class CartesianTransformer(torch.nn.Module): + """ + Cartesian Transformer implementation for handling 3D coordinates. + + :param cutoff: The cutoff distance for neighbor interactions. + :param cutoff_width: The width of the cutoff function. + :param d_model: The dimension of the model. + :param n_head: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param n_layers: The number of transformer layers. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param attention_temperature: The temperature scaling factor for attention scores. + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param n_atomic_species: The number of atomic species. + :param is_first: Whether this is the first transformer in the model. + """ + + def __init__( + self, + cutoff: float, + cutoff_width: float, + d_model: int, + n_head: int, + dim_node_features: int, + dim_feedforward: int, + n_layers: int, + norm: str, + activation: str, + attention_temperature: float, + transformer_type: str, + n_atomic_species: int, + is_first: bool, + ) -> None: + super(CartesianTransformer, self).__init__() + self.is_first = is_first + self.cutoff = cutoff + self.cutoff_width = cutoff_width + self.trans = Transformer( + d_model=d_model, + num_layers=n_layers, + n_heads=n_head, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + attention_temperature=attention_temperature, + ) + + self.l_max = 10 + self.n_rad = 32 + self.spherical_harmonics = SphericalHarmonics(l_max=self.l_max) + + self.edge_embedder = nn.Linear((self.l_max + 1) ** 2 + self.n_rad, d_model) + self.rmsnorm = nn.LayerNorm(d_model) + + if not is_first: + n_merge = 3 + else: + n_merge = 2 + + self.compress = nn.Sequential( + nn.Linear(n_merge * d_model, d_model), + torch.nn.SiLU(), + nn.Linear(d_model, d_model), + ) + + self.neighbor_embedder = DummyModule() # for torchscript + if not is_first: + self.neighbor_embedder = nn.Embedding(n_atomic_species, d_model) + + def forward( + self, + input_node_embeddings: torch.Tensor, + input_messages: torch.Tensor, + element_indices_neighbors: torch.Tensor, + edge_vectors: torch.Tensor, + padding_mask: torch.Tensor, + edge_distances: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the CartesianTransformer. 
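The forward pass below builds edge features from `(l_max + 1)**2` spherical-harmonic components (via sphericart) concatenated with a sine radial basis damped by the cutoff factors. Here is a standalone sketch of just the radial part, with made-up distances and a small basis size.

```python
import torch

cutoff, n_rad, eps = 4.5, 8, 1e-8
edge_distances = torch.tensor([[0.8, 2.3, 4.4]])   # (n_nodes, max_neighbors)
cutoff_factors = torch.tensor([[1.0, 1.0, 0.2]])   # smooth cutoff value per edge

r = (edge_distances / cutoff).clamp(0.0, 1.0)      # normalized distance in [0, 1]
n = torch.arange(1, n_rad + 1)                     # basis index 1..n_rad
rbf = torch.sin(torch.pi * r[:, :, None] * n) / (r[:, :, None] + eps)
rbf = rbf * cutoff_factors[:, :, None]             # features vanish smoothly at the cutoff

print(rbf.shape)  # torch.Size([1, 3, 8])
```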
+ + :param input_node_embeddings: The input node embeddings, of shape + (n_nodes, d_model) + :param input_messages: The input messages to the transformer, of shape + (n_nodes, max_num_neighbors, d_model) + :param element_indices_neighbors: The atomic species of the neighboring atoms, + of shape (n_nodes, max_num_neighbors) + :param edge_vectors: The cartesian edge vectors between the central atoms and + their neighbors, of shape (n_nodes, max_num_neighbors, 3) + :param padding_mask: A padding mask indicating which neighbors are real, and + which are padded, of shape (n_nodes, max_num_neighbors) + :param edge_distances: The distances between the central atoms and their + neighbors, of shape (n_nodes, max_num_neighbors) + :param cutoff_factors: The cutoff factors for the edges, of shape + (n_nodes, max_num_neighbors) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (n_nodes, d_model) + - The output edge embeddings, of shape (n_nodes, max_num_neighbors, d_model) + """ + node_embeddings = input_node_embeddings + + # Compute spherical harmonics + sph = self.spherical_harmonics(edge_vectors.reshape(-1, 3)).reshape( + edge_vectors.shape[0], edge_vectors.shape[1], -1 + ) + + # Compute Bessel RBFs + eps = 1e-8 + r = (edge_distances / self.cutoff).clamp(0.0, 1.0) + x = ( + torch.pi + * r[:, :, None] + * torch.arange(1, self.n_rad + 1, device=edge_vectors.device)[None, None, :] + ) + rbf = torch.sin(x) / (r[:, :, None] + eps) + rbf = rbf * cutoff_factors[:, :, None] + + edge_embeddings = torch.concat([sph, rbf], dim=-1) + edge_embeddings = self.edge_embedder(edge_embeddings) + edge_embeddings = self.rmsnorm(edge_embeddings) + + if not self.is_first: + neighbor_elements_embeddings = self.neighbor_embedder( + element_indices_neighbors + ) + edge_tokens = torch.cat( + [edge_embeddings, neighbor_elements_embeddings, input_messages], dim=2 + ) + else: + neighbor_elements_embeddings = torch.empty( + 0, device=edge_vectors.device, dtype=edge_vectors.dtype + ) # for torch script + edge_tokens = torch.cat([edge_embeddings, input_messages], dim=2) + + edge_tokens = self.compress(edge_tokens) + # tokens = torch.cat([node_elements_embedding[:, None, :], tokens], dim=1) + + padding_mask_with_central_token = torch.ones( + padding_mask.shape[0], dtype=torch.bool, device=padding_mask.device + ) + total_padding_mask = torch.cat( + [padding_mask_with_central_token[:, None], padding_mask], dim=1 + ) + + cutoff_subfactors = torch.ones( + padding_mask.shape[0], + dtype=cutoff_factors.dtype, + device=padding_mask.device, + ) + cutoff_factors = torch.cat([cutoff_subfactors[:, None], cutoff_factors], dim=1) + cutoff_factors[~total_padding_mask] = 0.0 + + cutoff_factors = cutoff_factors[:, None, :] + cutoff_factors = cutoff_factors.repeat(1, cutoff_factors.shape[2], 1) + + initial_num_tokens = edge_vectors.shape[1] + max_num_tokens = input_messages.shape[1] + + output_node_embeddings, output_edge_embeddings = self.trans( + node_embeddings[:, None, :], + edge_tokens[:, :max_num_tokens, :], + cutoff_factors=cutoff_factors[ + :, : (max_num_tokens + 1), : (max_num_tokens + 1) + ], + use_manual_attention=use_manual_attention, + ) + if max_num_tokens < initial_num_tokens: + padding = torch.zeros( + output_edge_embeddings.shape[0], + initial_num_tokens - max_num_tokens, 
+ output_edge_embeddings.shape[2], + device=output_edge_embeddings.device, + ) + output_edge_embeddings = torch.cat([output_edge_embeddings, padding], dim=1) + output_node_embeddings = output_node_embeddings.squeeze(1) + return output_node_embeddings, output_edge_embeddings + + +def manual_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: torch.Tensor, + temperature: float, +) -> torch.Tensor: + """ + Implements the attention operation manually, using basic PyTorch operations. + We need it because the built-in PyTorch attention does not support double backward, + which is needed when training with conservative forces. + + :param q: The queries + :param k: The keys + :param v: The values + :param attn_mask: The attention mask + :param temperature: An additional scaling factor for attention scores. + :return: The result of the attention operation + """ + attention_weights = ( + torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5 * temperature) + ) + attn_mask + attention_weights = attention_weights.softmax(dim=-1) + attention_output = torch.matmul(attention_weights, v) + return attention_output diff --git a/src/metatrain/pet_7/modules/utilities.py b/src/metatrain/pet_7/modules/utilities.py new file mode 100644 index 0000000000..5795c5aed8 --- /dev/null +++ b/src/metatrain/pet_7/modules/utilities.py @@ -0,0 +1,63 @@ +import torch + + +def cutoff_func_bump( + values: torch.Tensor, cutoff: torch.Tensor, width: float, eps: float = 1e-6 +) -> torch.Tensor: + """ + Bump cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :param eps: Avoid computing at values too close to the edges. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + f[mask_active] = 0.5 * ( + 1 + torch.tanh(1 / torch.tan(torch.pi * scaled_values[mask_active])) + ) + f[mask_smaller] = 1.0 + + return f + + +def cutoff_func_cosine( + values: torch.Tensor, cutoff: torch.Tensor, width: float +) -> torch.Tensor: + """ + Cosine cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + + f[mask_active] = 0.5 + 0.5 * torch.cos(torch.pi * scaled_values[mask_active]) + f[mask_smaller] = 1.0 + return f + + +class DummyModule(torch.nn.Module): + """Dummy torch module to make torchscript happy. 
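Both cutoff functions above map a distance window of width `width` below the cutoff onto a smooth 1 → 0 switch; outside that window they are exactly 1 (inside) or 0 (beyond). A quick evaluation of the cosine variant at a few illustrative distances; the bump variant only replaces the cosine with `0.5 * (1 + tanh(1 / tan(pi * s)))` on the active region.

```python
import torch

values = torch.tensor([3.5, 4.2, 4.45, 4.6])  # distances
cutoff = torch.full_like(values, 4.5)         # per-edge cutoff radius
width = 0.5

s = (values - (cutoff - width)) / width       # 0 at cutoff - width, 1 at cutoff
cosine = torch.where(
    s <= 0.0,
    torch.ones_like(s),
    torch.where(s < 1.0, 0.5 + 0.5 * torch.cos(torch.pi * s), torch.zeros_like(s)),
)
print(cosine)  # ~[1.0000, 0.6545, 0.0245, 0.0000]
```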
+ This model should never be run""" + + def __init__(self) -> None: + super(DummyModule, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise RuntimeError("This model should never be run") diff --git a/src/metatrain/pet_7/trainer.py b/src/metatrain/pet_7/trainer.py new file mode 100644 index 0000000000..49e82212bc --- /dev/null +++ b/src/metatrain/pet_7/trainer.py @@ -0,0 +1,639 @@ +import copy +import logging +import math +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Union, cast + +import torch +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader, DistributedSampler + +from metatrain.utils.abc import ModelInterface, TrainerInterface +from metatrain.utils.additive import get_remove_additive_transform +from metatrain.utils.augmentation import RotationalAugmenter +from metatrain.utils.data import ( + CollateFn, + CombinedDataLoader, + Dataset, + get_num_workers, + unpack_batch, + validate_num_workers, +) +from metatrain.utils.distributed.batch_utils import should_skip_batch +from metatrain.utils.distributed.distributed_data_parallel import ( + DistributedDataParallel, +) +from metatrain.utils.distributed.slurm import DistributedEnvironment +from metatrain.utils.evaluate_model import evaluate_model +from metatrain.utils.io import check_file_extension +from metatrain.utils.logging import ROOT_LOGGER, MetricLogger +from metatrain.utils.loss import LossAggregator, LossSpecification +from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists_transform, +) +from metatrain.utils.per_atom import average_by_num_atoms +from metatrain.utils.scaler import get_remove_scale_transform +from metatrain.utils.transfer import batch_to + +from . import checkpoints +from .documentation import TrainerHypers +from .model import PET +from .modules.finetuning import apply_finetuning_strategy + + +def get_scheduler( + optimizer: torch.optim.Optimizer, + train_hypers: TrainerHypers, + steps_per_epoch: int, +) -> LambdaLR: + """ + Get a CosineAnnealing learning-rate scheduler with warmup + + :param optimizer: The optimizer for which to create the scheduler. + :param train_hypers: The training hyperparameters. + :param steps_per_epoch: The number of steps per epoch. + :return: The learning rate scheduler. 
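A worked example of the resulting schedule, under the assumption of 1000 total steps and a warmup fraction of 0.1 (so 100 warmup steps); the optimizer and base learning rate are placeholders, not metatrain defaults.

```python
import math

import torch
from torch.optim.lr_scheduler import LambdaLR

total_steps, warmup_steps = 1000, 100
opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)

def lr_lambda(step: int) -> float:
    if step < warmup_steps:
        return step / max(1, warmup_steps)                     # linear warmup
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))          # cosine decay to 0

sched = LambdaLR(opt, lr_lambda=lr_lambda)
for step in range(total_steps):
    opt.step()
    sched.step()
    if step in (0, 99, 549, 999):
        print(step, opt.param_groups[0]["lr"])
# prints roughly: 1e-05 (start of warmup), 1e-03 (warmup done), 5e-04 (mid-decay), 0.0 (end)
```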
+ """ + total_steps = train_hypers["num_epochs"] * steps_per_epoch + warmup_steps = int(train_hypers["warmup_fraction"] * total_steps) + min_lr_ratio = 0.0 # hardcoded for now, could be made configurable in the future + + def lr_lambda(current_step: int) -> float: + if current_step < warmup_steps: + # Linear warmup + return float(current_step) / float(max(1, warmup_steps)) + else: + # Cosine decay + progress = (current_step - warmup_steps) / float( + max(1, total_steps - warmup_steps) + ) + cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) + return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay + + scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) + return scheduler + + +class Trainer(TrainerInterface[TrainerHypers]): + __checkpoint_version__ = 12 + + def __init__(self, hypers: TrainerHypers) -> None: + super().__init__(hypers) + + self.optimizer_state_dict: Optional[Dict[str, Any]] = None + self.scheduler_state_dict: Optional[Dict[str, Any]] = None + self.epoch: Optional[int] = None + self.best_epoch: Optional[int] = None + self.best_metric: Optional[float] = None + self.best_model_state_dict: Optional[Dict[str, Any]] = None + self.best_optimizer_state_dict: Optional[Dict[str, Any]] = None + + def train( + self, + model: PET, + dtype: torch.dtype, + devices: List[torch.device], + train_datasets: List[Union[Dataset, torch.utils.data.Subset]], + val_datasets: List[Union[Dataset, torch.utils.data.Subset]], + checkpoint_dir: str, + ) -> None: + assert dtype in PET.__supported_dtypes__ + + is_distributed = self.hypers["distributed"] + is_finetune = self.hypers["finetune"]["read_from"] is not None + + if is_distributed: + if len(devices) > 1: + raise ValueError( + "Requested distributed training with the `multi-gpu` device. " + " If you want to run distributed training with PET, please " + "set `device` to cuda." 
+ ) + # the calculation of the device number works both when GPUs on different + # processes are not visible to each other and when they are + distr_env = DistributedEnvironment(self.hypers["distributed_port"]) + device_number = distr_env.local_rank % torch.cuda.device_count() + device = torch.device("cuda", device_number) + torch.distributed.init_process_group(backend="nccl", device_id=device) + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + else: + rank = 0 + device = devices[0] + # only one device, as we don't support non-distributed multi-gpu for now + + if is_distributed: + logging.info(f"Training on {world_size} devices with dtype {dtype}") + else: + logging.info(f"Training on device {device} with dtype {dtype}") + + # Apply fine-tuning strategy if provided + if is_finetune: + assert self.hypers["finetune"]["read_from"] is not None # for mypy + model = apply_finetuning_strategy(model, self.hypers["finetune"]) + method = self.hypers["finetune"]["method"] + num_params = sum(p.numel() for p in model.parameters()) + num_trainable_params = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + + logging.info(f"Applied finetuning strategy: {method}") + logging.info( + f"Number of trainable parameters: {num_trainable_params} " + f"[{num_trainable_params / num_params:.2%} %]" + ) + inherit_heads = self.hypers["finetune"]["inherit_heads"] + if inherit_heads: + logging.info( + "Inheriting initial weights for heads and last layers for targets: " + f"from {list(inherit_heads.values())} to " + f"{list(inherit_heads.keys())}" + ) + + # Move the model to the device and dtype: + model.to(device=device, dtype=dtype) + # The additive models of PET are always in float64 (to avoid numerical errors in + # the composition weights, which can be very large). 
+ for additive_model in model.additive_models: + additive_model.to(dtype=torch.float64) + model.scaler.to(dtype=torch.float64) + + logging.info("Calculating composition weights") + model.additive_models[0].train_model( # this is the composition model + train_datasets, + model.additive_models[1:], + self.hypers["batch_size"], + is_distributed, + self.hypers["atomic_baseline"], + ) + + if self.hypers["scale_targets"]: + logging.info("Calculating scaling weights") + model.scaler.train_model( + train_datasets, + model.additive_models, + self.hypers["batch_size"], + is_distributed, + self.hypers["fixed_scaling_weights"], + ) + + logging.info("Setting up data loaders") + + if is_distributed: + train_samplers = [ + DistributedSampler( + train_dataset, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + ) + for train_dataset in train_datasets + ] + val_samplers = [ + DistributedSampler( + val_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + for val_dataset in val_datasets + ] + else: + train_samplers = [None] * len(train_datasets) + val_samplers = [None] * len(val_datasets) + + # Extract additive models and scaler and move them to CPU/float64 so they + # can be used in the collate function + model.additive_models[0].weights_to(device="cpu", dtype=torch.float64) + additive_models = copy.deepcopy( + model.additive_models.to(dtype=torch.float64, device="cpu") + ) + model.additive_models.to(device) + model.additive_models[0].weights_to(device=device, dtype=torch.float64) + model.scaler.scales_to(device="cpu", dtype=torch.float64) + scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu")) + model.scaler.to(device) + model.scaler.scales_to(device=device, dtype=torch.float64) + + # Create collate functions: + dataset_info = model.dataset_info + train_targets = dataset_info.targets + extra_data_info = dataset_info.extra_data + rotational_augmenter = RotationalAugmenter( + target_info_dict=train_targets, extra_data_info_dict=extra_data_info + ) + requested_neighbor_lists = get_requested_neighbor_lists(model) + collate_fn_train = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ + rotational_augmenter.apply_random_augmentations, + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + collate_fn_val = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ # no augmentation for validation + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + + # Create dataloader for the training datasets: + if self.hypers["num_workers"] is None: + num_workers = get_num_workers() + logging.info( + "Number of workers for data-loading not provided and chosen " + f"automatically. Using {num_workers} workers." + ) + else: + num_workers = self.hypers["num_workers"] + validate_num_workers(num_workers) + + train_dataloaders = [] + for train_dataset, train_sampler in zip( + train_datasets, train_samplers, strict=True + ): + if len(train_dataset) < self.hypers["batch_size"]: + raise ValueError( + f"A training dataset has fewer samples " + f"({len(train_dataset)}) than the batch size " + f"({self.hypers['batch_size']}). " + "Please reduce the batch size." 
+ ) + train_dataloaders.append( + DataLoader( + dataset=train_dataset, + batch_size=self.hypers["batch_size"], + sampler=train_sampler, + shuffle=( + # the sampler takes care of this (if present) + train_sampler is None + ), + drop_last=( + # the sampler takes care of this (if present) + train_sampler is None + ), + collate_fn=collate_fn_train, + num_workers=num_workers, + ) + ) + train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True) + + # Create dataloader for the validation datasets: + val_dataloaders = [] + for val_dataset, val_sampler in zip(val_datasets, val_samplers, strict=True): + val_dataloaders.append( + DataLoader( + dataset=val_dataset, + batch_size=self.hypers["batch_size"], + sampler=val_sampler, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_val, + num_workers=num_workers, + ) + ) + val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False) + + if is_distributed: + model = DistributedDataParallel(model, device_ids=[device]) + + outputs_list = [] + for target_name, target_info in train_targets.items(): + outputs_list.append(target_name) + for gradient_name in target_info.gradients: + outputs_list.append(f"{target_name}_{gradient_name}_gradients") + + # Create a loss function: + loss_hypers = cast(Dict[str, LossSpecification], self.hypers["loss"]) # mypy + loss_fn = LossAggregator(targets=train_targets, config=loss_hypers) + logging.info("Using the following loss functions:") + for name, info in loss_fn.metadata.items(): + logging.info(f"{name}:") + main = {k: v for k, v in info.items() if k != "gradients"} + logging.info(main) + if "gradients" not in info or len(info["gradients"]) == 0: + continue + logging.info("With gradients:") + for grad, ginfo in info["gradients"].items(): + logging.info(f"\t{name}::{grad}: {ginfo}") + + if self.hypers["weight_decay"] is not None: + optimizer = torch.optim.AdamW( + model.parameters(), + lr=self.hypers["learning_rate"], + weight_decay=self.hypers["weight_decay"], + ) + else: + optimizer = torch.optim.Adam( + model.parameters(), lr=self.hypers["learning_rate"] + ) + + if self.optimizer_state_dict is not None and not is_finetune: + # try to load the optimizer state dict, but this is only possible + # if there are no new targets in the model (new parameters) + if not (model.module if is_distributed else model).has_new_targets: + optimizer.load_state_dict(self.optimizer_state_dict) + + # Create a learning rate scheduler + lr_scheduler = get_scheduler(optimizer, self.hypers, len(train_dataloader)) + + if self.scheduler_state_dict is not None and not is_finetune: + # same as the optimizer, try to load the scheduler state dict + if not (model.module if is_distributed else model).has_new_targets: + lr_scheduler.load_state_dict(self.scheduler_state_dict) + + per_structure_targets = self.hypers["per_structure_targets"] + + # Log the initial learning rate: + logging.info(f"Base learning rate: {self.hypers['learning_rate']}") + + start_epoch = -1 if self.epoch is None else self.epoch + epoch = start_epoch + + # Save the untrained model checkpoint: + if self.best_metric is None: + self.best_metric = float("inf") + logging.info("Saving untrained model checkpoint") + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = 
lr_scheduler.state_dict() + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_untrained.ckpt", + ) + + # Train the model: + logging.info("Starting training") + start_epoch = start_epoch + 1 + + for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]): + if is_distributed: + for train_sampler in train_samplers: + train_sampler.set_epoch(epoch) + train_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + val_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + if self.hypers["log_mae"]: + train_mae_calculator = MAEAccumulator( + self.hypers["log_separate_blocks"] + ) + val_mae_calculator = MAEAccumulator(self.hypers["log_separate_blocks"]) + + train_loss = 0.0 + for batch in train_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + optimizer.zero_grad() + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=True, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + train_loss_batch = loss_fn(predictions, targets, extra_data) + + if is_distributed: + # make sure all parameters contribute to the gradient calculation + # to make torch DDP happy + for param in model.parameters(): + train_loss_batch += 0.0 * param.sum() + + train_loss_batch.backward() + torch.nn.utils.clip_grad_norm_( + model.parameters(), self.hypers["grad_clip_norm"] + ) + optimizer.step() + lr_scheduler.step() + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(train_loss_batch) + train_loss += train_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + train_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + train_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_train_info = train_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_train_info.update( + train_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + val_loss = 0.0 + for batch in val_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=False, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + val_loss_batch = loss_fn(predictions, targets, extra_data) + + if 
is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(val_loss_batch) + val_loss += val_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + val_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + val_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_val_info = val_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_val_info.update( + val_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + # Now we log the information: + finalized_train_info = { + "loss": train_loss, + **finalized_train_info, + } + finalized_val_info = { + "loss": val_loss, + **finalized_val_info, + } + + if epoch == start_epoch: + metric_logger = MetricLogger( + log_obj=ROOT_LOGGER, + dataset_info=( + model.module if is_distributed else model + ).dataset_info, + initial_metrics=[finalized_train_info, finalized_val_info], + names=["training", "validation"], + ) + if epoch % self.hypers["log_interval"] == 0: + metric_logger.log( + metrics=[finalized_train_info, finalized_val_info], + epoch=epoch, + rank=rank, + learning_rate=optimizer.param_groups[0]["lr"], + ) + + val_metric = get_selected_metric( + finalized_val_info, self.hypers["best_model_metric"] + ) + if val_metric < self.best_metric: + self.best_metric = val_metric + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + if epoch % self.hypers["checkpoint_interval"] == 0: + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + self.epoch = epoch + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_{epoch}.ckpt", + ) + + # prepare for the checkpoint that will be saved outside the function + self.epoch = epoch + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + + if is_distributed: + torch.distributed.destroy_process_group() + + def save_checkpoint(self, model: ModelInterface, path: Union[str, Path]) -> None: + checkpoint = model.get_checkpoint() + if self.best_model_state_dict is not None: + self.best_model_state_dict["finetune_config"] = model.finetune_config + checkpoint.update( + { + "trainer_ckpt_version": self.__checkpoint_version__, + "train_hypers": self.hypers, + "epoch": self.epoch, + "optimizer_state_dict": self.optimizer_state_dict, + "scheduler_state_dict": self.scheduler_state_dict, + "best_epoch": self.best_epoch, + "best_metric": self.best_metric, + "best_model_state_dict": self.best_model_state_dict, + "best_optimizer_state_dict": self.best_optimizer_state_dict, + } + ) + torch.save( + checkpoint, + check_file_extension(path, ".ckpt"), + ) + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + hypers: TrainerHypers, + context: Literal["restart", "finetune"], + ) -> "Trainer": + trainer = cls(hypers) + trainer.optimizer_state_dict = 
checkpoint["optimizer_state_dict"] + trainer.scheduler_state_dict = checkpoint["scheduler_state_dict"] + trainer.epoch = checkpoint["epoch"] + trainer.best_epoch = checkpoint["best_epoch"] + trainer.best_metric = checkpoint["best_metric"] + trainer.best_model_state_dict = checkpoint["best_model_state_dict"] + trainer.best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"] + + return trainer + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["trainer_ckpt_version"] == v: + update = getattr(checkpoints, f"trainer_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["trainer_ckpt_version"] = v + 1 + + if checkpoint["trainer_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using " + f"trainer version {checkpoint['trainer_ckpt_version']}, while the " + f"current trainer version is {cls.__checkpoint_version__}." + ) + return checkpoint diff --git a/src/metatrain/pet_8/__init__.py b/src/metatrain/pet_8/__init__.py new file mode 100644 index 0000000000..01db7d8aa9 --- /dev/null +++ b/src/metatrain/pet_8/__init__.py @@ -0,0 +1,21 @@ +from .model import PET +from .trainer import Trainer + + +__model__ = PET +__trainer__ = Trainer +__capabilities__ = { + "supported_devices": __model__.__supported_devices__, + "supported_dtypes": __model__.__supported_dtypes__, +} + +__authors__ = [ + ("Sergey Pozdnyakov ", "@spozdn"), + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] + +__maintainers__ = [ + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] diff --git a/src/metatrain/pet_8/checkpoints.py b/src/metatrain/pet_8/checkpoints.py new file mode 100644 index 0000000000..020b1e6590 --- /dev/null +++ b/src/metatrain/pet_8/checkpoints.py @@ -0,0 +1,422 @@ +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + + +########################### +# MODEL ################### +########################### + + +def model_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "additive_models.0.model.type_to_index" not in state_dict: + state_dict["additive_models.0.model.type_to_index"] = state_dict.pop( + "additive_models.0.type_to_index" + ) + + +def model_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "train_hypers" in checkpoint: + finetune_config = checkpoint["train_hypers"].get("finetune", {}) + else: + finetune_config = {} + state_dict["finetune_config"] = finetune_config + + +def model_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["epoch"] = checkpoint.get("epoch") + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + if checkpoint["best_model_state_dict"] is not None: + checkpoint["best_model_state_dict"] = checkpoint.get("best_model_state_dict") + + +def model_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 checkpoint to v5. + + :param checkpoint: The checkpoint to update. 
+ """ + checkpoint["model_data"]["dataset_info"]._atomic_types = list( + checkpoint["model_data"]["dataset_info"]._atomic_types + ) + + +def model_update_v5_v6(checkpoint: dict) -> None: + """ + Update a v5 checkpoint to v6. + + :param checkpoint: The checkpoint to update. + """ + if not checkpoint["best_model_state_dict"]: + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + +def model_update_v6_v7(checkpoint: dict) -> None: + """ + Update model checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to be updated. + """ + # this update consists in changes in the scaler + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if ( + "scaler.scales" not in state_dict + and "scaler.dummy_buffer" in state_dict + and "scaler.model.type_to_index" in state_dict + ): + continue # already updated + old_scales_tensor = state_dict.pop("scaler.scales") + old_output_name_to_output_index = {} + for target_index, target_name in enumerate( + checkpoint["model_data"]["dataset_info"].targets.keys() + ): + old_output_name_to_output_index[target_name] = target_index + state_dict["scaler.dummy_buffer"] = torch.tensor( + [0.0], dtype=old_scales_tensor.dtype + ) + state_dict["scaler.model.type_to_index"] = state_dict[ + "additive_models.0.model.type_to_index" + ] + for target_name, target_info in checkpoint["model_data"][ + "dataset_info" + ].targets.items(): + layout = target_info.layout + if layout.sample_names == ["system"]: + samples = Labels(["atomic_type"], torch.tensor([[-1]])) + + elif layout.sample_names == ["system", "atom"]: + samples = Labels( + ["atomic_type"], + torch.arange( + len(checkpoint["model_data"]["dataset_info"].atomic_types) + ).reshape(-1, 1), + ) + else: + raise ValueError( # will never happen + "Unknown sample kind. Please contact the developers." + ) + scales_tensormap = TensorMap( + keys=layout.keys, + blocks=[ + TensorBlock( + values=torch.full( # important when scale_targets=False + (len(samples), len(block.properties)), + old_scales_tensor[ + old_output_name_to_output_index[target_name] + ], + dtype=torch.float64, + ), + samples=samples, + components=[], + properties=block.properties, + ) + for block in layout.blocks() + ], + ) + state_dict[f"scaler.{target_name}_scaler_buffer"] = mts.save_buffer( + mts.make_contiguous(scales_tensormap) + ) + + +def model_update_v7_v8(checkpoint: dict) -> None: + """ + Update a v7 checkpoint to v8. + + :param checkpoint: The checkpoint to update. + """ + if "node_embedding.weight" in checkpoint["model_state_dict"]: + ############################################## + # **Updating the large-scale PET checkpoints** + ############################################## + + checkpoint["model_data"]["model_hypers"]["normalization"] = "RMSNorm" + checkpoint["model_data"]["model_hypers"]["activation"] = "SwiGLU" + checkpoint["model_data"]["model_hypers"]["d_node"] = ( + 4 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PreLN" + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "feedforward" + checkpoint["model_data"]["model_hypers"]["d_feedforward"] = ( + 2 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + if (state_dict := checkpoint.get("model_state_dict")) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + if "embedding." 
in k and "node" not in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if "node_embedding." in k: + k = k.replace("node_embedding.", "node_embedders.0.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + if "combination_rmsnorms" in k: + k = k.replace("combination_rmsnorms", "combination_norms") + new_state_dict[k] = v + checkpoint["model_state_dict"] = new_state_dict + + # for the large-scale checkpoints, the model for evaluation is always + # taken to be the last + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + else: + ############################################### + # **Updating the standard old PET checkpoints** + ############################################### + + # Adding the option for choosing the normalization type + if "normalization" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["normalization"] = "LayerNorm" + # Adding the option for choosing the activation function + if "activation" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["activation"] = "SiLU" + # Setting the node features dimension to be the same as d_pet if not specified + if "d_node" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["d_node"] = checkpoint[ + "model_data" + ]["model_hypers"]["d_pet"] + # Setting the default transformer type to PostLN if not specified + if "transformer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PostLN" + if "featurizer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "residual" + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + # Replacing the nn.Sequential MLP with a custom FeedForward module + if ".mlp.0" in k: + k = k.replace(".mlp.0", ".mlp.w_in") + if ".mlp.3" in k: + k = k.replace(".mlp.3", ".mlp.w_out") + # Moving the node embedder to a top-level PET model attribute + if "embedding." in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".node_embedder." in k: + key_content = k.split(".") + k = ".".join( + ["node_embedders", key_content[1], key_content[-1]] + ) + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + new_state_dict[k] = v + checkpoint[key] = new_state_dict + + +def model_update_v8_v9(checkpoint: dict) -> None: + """ + Update a v8 checkpoint to v9. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "finetune_config" in state_dict: + if "inherit_heads" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["inherit_heads"] = {} + if "method" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["method"] = "full" + + +def model_update_v9_v10(checkpoint: dict) -> None: + """ + Update a v9 checkpoint to v10. + + :param checkpoint: The checkpoint to update. 
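`model_update_v7_v8` above is essentially a state-dict migration: keys are renamed to match the refactored module layout, and embedding tables drop their last row (the old padding type). A minimal sketch with made-up keys and shapes.

```python
import torch

old_state_dict = {
    "gnn_layers.0.mlp.0.weight": torch.randn(16, 8),
    "embedding.weight": torch.randn(5, 8),  # 4 species + 1 padding row
}

new_state_dict = {}
for key, value in old_state_dict.items():
    if ".mlp.0" in key:
        key = key.replace(".mlp.0", ".mlp.w_in")          # nn.Sequential slot -> named module
    if "embedding." in key:
        key = key.replace("embedding.", "edge_embedder.")
        value = value[:-1, :]                              # drop the padding-type row
    new_state_dict[key] = value

print({k: tuple(v.shape) for k, v in new_state_dict.items()})
# {'gnn_layers.0.mlp.w_in.weight': (16, 8), 'edge_embedder.weight': (4, 8)}
```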
+ """ + # Adding the num_neighbors_adaptive hyperparameter if not present + if "num_neighbors_adaptive" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["num_neighbors_adaptive"] = None + if "cutoff_function" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["cutoff_function"] = "Cosine" + + +def model_update_v10_v11(checkpoint: dict) -> None: + """ + Update a v10 checkpoint to v11. + + :param checkpoint: The checkpoint to update. + """ + # Adding the attention_temperature hyperparameter if not present + if "attention_temperature" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["attention_temperature"] = 1.0 + + +########################### +# TRAINER ################# +########################### + + +def trainer_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 Trainer checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["scheduler_factor"] = checkpoint["train_hypers"].get( + "scheduler_factor", 0.5 + ) + + +def trainer_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 Trainer checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + +def trainer_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 Trainer checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + new_loss_hypers[target_name] = { + "type": old_loss_hypers["type"], + "weight": old_loss_hypers["weights"].get(target_name, 1.0), + "reduction": old_loss_hypers["reduction"], + "sliding_factor": old_loss_hypers.get("sliding_factor", None), + } + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 Trainer checkpoint to v5. + + :param checkpoint: The checkpoint to update. + """ + raise ValueError( + "In order to use this checkpoint, you need metatrain 2025.10 or earlier. " + "You can install it with `pip install metatrain==2025.10`." + ) + + +def trainer_update_v5_v6(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 5 to version 6. + + :param checkpoint: The checkpoint to update. + """ + # num_workers=0 means that the main process will do the data loading, which is + # equivalent to not setting it (this was the behavior before v6) + checkpoint["train_hypers"]["num_workers"] = 0 + + +def trainer_update_v6_v7(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["fixed_scaling_weights"] = {} + + +def trainer_update_v7_v8(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. 
+    """
+    # remove all `sliding_factor` entries from the loss hypers
+    old_loss_hypers = checkpoint["train_hypers"]["loss"].copy()
+    dataset_info = checkpoint["model_data"]["dataset_info"]
+    new_loss_hypers = {}
+
+    for target_name in dataset_info.targets.keys():
+        # retain everything except sliding_factor for each target
+        new_loss_hypers[target_name] = {
+            k: v
+            for k, v in old_loss_hypers[target_name].items()
+            if k != "sliding_factor"
+        }
+
+    checkpoint["train_hypers"]["loss"] = new_loss_hypers
+
+
+def trainer_update_v8_v9(checkpoint: dict) -> None:
+    """
+    Update trainer checkpoint from version 8 to version 9.
+
+    :param checkpoint: The checkpoint to update.
+    """
+    # Adding the empty finetune config if not present
+    if "finetune" not in checkpoint["train_hypers"]:
+        checkpoint["train_hypers"]["finetune"] = {
+            "read_from": None,
+            "method": "full",
+            "config": {},
+            "inherit_heads": {},
+        }
+
+
+def trainer_update_v9_v10(checkpoint: dict) -> None:
+    """
+    Update trainer checkpoint from version 9 to version 10.
+
+    :param checkpoint: The checkpoint to update.
+    """
+    # Adding the remove_composition_contribution hyperparameter, enabled by default
+    checkpoint["train_hypers"]["remove_composition_contribution"] = True
+
+
+def trainer_update_v10_v11(checkpoint: dict) -> None:
+    """
+    Update trainer checkpoint from version 10 to version 11.
+
+    :param checkpoint: The checkpoint to update.
+    """
+    # - Remove the ``remove_composition_contribution`` hyper.
+    # - Rename ``fixed_composition_weights`` to ``atomic_baseline``.
+    # - If ``remove_composition_contribution`` is False, set all atomic baselines
+    #   to 0.0 for all targets.
+    use_atomic_baseline = checkpoint["train_hypers"].pop(
+        "remove_composition_contribution"
+    )
+    atomic_baseline = checkpoint["train_hypers"].pop("fixed_composition_weights")
+
+    if not use_atomic_baseline:
+        # Set all atomic baselines to 0.0, effectively disabling them for all targets
+        dataset_info = checkpoint["model_data"]["dataset_info"]
+        atomic_baseline = {target_name: 0.0 for target_name in dataset_info.targets}
+
+    checkpoint["train_hypers"]["atomic_baseline"] = atomic_baseline
+
+
+def trainer_update_v11_v12(checkpoint: dict) -> None:
+    """
+    Update trainer checkpoint from version 11 to version 12.
+
+    :param checkpoint: The checkpoint to update.
+    """
+    checkpoint["train_hypers"]["batch_atom_bounds"] = [None, None]
diff --git a/src/metatrain/pet_8/documentation.py b/src/metatrain/pet_8/documentation.py
new file mode 100644
index 0000000000..b4682a7cca
--- /dev/null
+++ b/src/metatrain/pet_8/documentation.py
@@ -0,0 +1,251 @@
+"""
+PET
+===
+
+PET is a cleaner, more user-friendly reimplementation of the original
+PET model :footcite:p:`pozdnyakov_smooth_2023`. It is designed for better
+modularity and maintainability, while preserving compatibility with the original
+PET implementation in ``metatrain``. It also adds new features like long-range
+features, a better fine-tuning implementation, the possibility to train on
+arbitrary targets, and faster inference due to ``fast attention``.
+
+{{SECTION_INSTALLATION}}
+
+{{SECTION_DEFAULT_HYPERS}}
+
+Tuning hyperparameters
+----------------------
+
+The default hyperparameters above will work well in most cases, but they
+may not be optimal for your specific dataset. There is a good number of
+parameters to tune, both for the
+:ref:`model ` and the
+:ref:`trainer `. Since seeing them
+for the first time might be overwhelming, here we provide a **list of the
+parameters that are in general the most important** (in decreasing order
+of importance):
+
+.. container:: mtt-hypers-remove-classname
+
+    .. 
autoattribute:: {{model_hypers_path}}.cutoff + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_neighbors_adaptive + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.learning_rate + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.batch_size + :no-index: + + .. autoattribute:: {{model_hypers_path}}.d_pet + :no-index: + + .. autoattribute:: {{model_hypers_path}}.d_node + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_gnn_layers + :no-index: + + .. autoattribute:: {{model_hypers_path}}.num_attention_layers + :no-index: + + .. autoattribute:: {{trainer_hypers_path}}.loss + :no-index: + + .. autoattribute:: {{model_hypers_path}}.long_range + :no-index: +""" + +from typing import Literal, Optional + +from typing_extensions import TypedDict + +from metatrain.pet_8.modules.finetuning import FinetuneHypers, NoFinetuneHypers +from metatrain.utils.additive import FixedCompositionWeights +from metatrain.utils.hypers import init_with_defaults +from metatrain.utils.long_range import LongRangeHypers +from metatrain.utils.loss import LossSpecification +from metatrain.utils.scaler import FixedScalerWeights + + +class ModelHypers(TypedDict): + """Hyperparameters for the PET model.""" + + cutoff: float = 4.5 + """Cutoff radius for neighbor search. + + This should be set to a value after which most of the interactions + between atoms is expected to be negligible. A lower cutoff will lead + to faster models. + """ + num_neighbors_adaptive: Optional[int] = None + """Target number of neighbors for the adaptive cutoff scheme. + + This parameter activates the adaptive cutoff functionality. + Each atomic environments has a different cutoff, that is chosen + such that the number of neighbors is approximately equal to this + value. This can be useful to have a more uniform number of neighbors + per atom, especially in sparse systems. Setting it to None disables + this feature and uses all neighbors within the fixed cutoff radius. + """ + cutoff_function: Literal["Cosine", "Bump"] = "Bump" + """Type of the smoothing function at the cutoff""" + cutoff_width: float = 0.5 + """Width of the smoothing function at the cutoff""" + d_pet: int = 128 + """Dimension of the edge features. + + This hyperparameters controls width of the neural network. In general, + increasing it might lead to better accuracy, especially on larger datasets, at the + cost of increased training and evaluation time. + """ + d_head: int = 128 + """Dimension of the attention heads.""" + d_node: int = 256 + """Dimension of the node features. + + Increasing this hyperparameter might lead to better accuracy, + with a relatively small increase in inference time. + """ + d_feedforward: int = 256 + """Dimension of the feedforward network in the attention layer.""" + num_heads: int = 8 + """Attention heads per attention layer.""" + num_attention_layers: int = 2 + """The number of attention layers in each layer of the graph + neural network. Depending on the dataset, increasing this hyperparameter might + lead to better accuracy, at the cost of increased training and evaluation time. + """ + num_gnn_layers: int = 2 + """The number of graph neural network layers. + + In general, decreasing this hyperparameter to 1 will lead to much faster models, + at the expense of accuracy. Increasing it may or may not lead to better accuracy, + depending on the dataset, at the cost of increased training and evaluation time. 
+ """ + normalization: Literal["RMSNorm", "LayerNorm"] = "RMSNorm" + """Layer normalization type.""" + activation: Literal["SiLU", "SwiGLU"] = "SwiGLU" + """Activation function.""" + attention_temperature: float = 1.0 + """The temperature scaling factor for attention scores.""" + transformer_type: Literal["PreLN", "PostLN"] = "PreLN" + """The order in which the layer normalization and attention + are applied in a transformer block. Available options are ``PreLN`` + (normalization before attention) and ``PostLN`` (normalization after attention).""" + featurizer_type: Literal["residual", "feedforward"] = "feedforward" + """Implementation of the featurizer of the model to use. Available + options are ``residual`` (the original featurizer from the PET paper, that uses + residual connections at each GNN layer for readout) and ``feedforward`` (a modern + version that uses the last representation after all GNN iterations for readout). + Additionally, the feedforward version uses bidirectional features flow during the + message passing iterations, that favors features flowing from atom ``i`` to atom + ``j`` to be not equal to the features flowing from atom ``j`` to atom ``i``.""" + zbl: bool = False + """Use ZBL potential for short-range repulsion""" + long_range: LongRangeHypers = init_with_defaults(LongRangeHypers) + """Long-range Coulomb interactions parameters.""" + + +class TrainerHypers(TypedDict): + """Hyperparameters for training PET models.""" + + distributed: bool = False + """Whether to use distributed training""" + distributed_port: int = 39591 + """Port for distributed communication among processes""" + batch_size: int = 16 + """The number of samples to use in each batch of training. This + hyperparameter controls the tradeoff between training speed and memory usage. In + general, larger batch sizes will lead to faster training, but might require more + memory.""" + num_epochs: int = 1000 + """Number of epochs.""" + warmup_fraction: float = 0.01 + """Fraction of training steps used for learning rate warmup.""" + learning_rate: float = 1e-4 + """Learning rate.""" + weight_decay: Optional[float] = None + + log_interval: int = 1 + """Interval to log metrics.""" + checkpoint_interval: int = 100 + """Interval to save checkpoints.""" + atomic_baseline: FixedCompositionWeights = {} + """The baselines for each target. + + By default, ``metatrain`` will fit a linear model (:class:`CompositionModel + `) to compute the + least squares baseline for each atomic species for each target. + + However, this hyperparameter allows you to provide your own baselines. + The value of the hyperparameter should be a dictionary where the keys are the + target names, and the values are either (1) a single baseline to be used for + all atomic types, or (2) a dictionary mapping atomic types to their baselines. + For example: + + - ``atomic_baseline: {"energy": {1: -0.5, 6: -10.0}}`` will fix the energy + baseline for hydrogen (Z=1) to -0.5 and for carbon (Z=6) to -10.0, while + fitting the baselines for the energy of all other atomic types, as well + as fitting the baselines for all other targets. + - ``atomic_baseline: {"energy": -5.0}`` will fix the energy baseline for + all atomic types to -5.0. + - ``atomic_baseline: {"mtt:dos": 0.0}`` sets the baseline for the "mtt:dos" + target to 0.0, effectively disabling the atomic baseline for that target. 
+ + This atomic baseline is substracted from the targets during training, which + avoids the main model needing to learn atomic contributions, and likely makes + training easier. When the model is used in evaluation mode, the atomic baseline + is added on top of the model predictions automatically. + + .. note:: + + This atomic baseline is a per-atom contribution. Therefore, if the property + you are predicting is a sum over all atoms (e.g., total energy), the + contribution of the atomic baseline to the total property will be the + atomic baseline multiplied by the number of atoms of that type in the + structure. + """ + scale_targets: bool = True + """Normalize targets to unit std during training.""" + fixed_scaling_weights: FixedScalerWeights = {} + """Weights for target scaling. + + This is passed to the ``fixed_weights`` argument of + :meth:`Scaler.train_model `, + see its documentation to understand exactly what to pass here. + """ + per_structure_targets: list[str] = [] + """Targets to calculate per-structure losses.""" + num_workers: Optional[int] = None + """Number of workers for data loading. If not provided, it is set + automatically.""" + log_mae: bool = True + """Log MAE alongside RMSE""" + log_separate_blocks: bool = False + """Log per-block error.""" + best_model_metric: Literal["rmse_prod", "mae_prod", "loss"] = "mae_prod" + """Metric used to select best checkpoint (e.g., ``rmse_prod``)""" + grad_clip_norm: float = 1.0 + """Maximum gradient norm value.""" + loss: str | dict[str, LossSpecification | str] = "mse" + """This section describes the loss function to be used. See the + :ref:`loss-functions` for more details.""" + batch_atom_bounds: list[Optional[int]] = [None, None] + """Bounds for the number of atoms per batch as [min, max]. Batches with atom + counts outside these bounds will be skipped during training. Use ``None`` for + either value to disable that bound. This is useful for preventing out-of-memory + errors and ensuring consistent computational load. Default: ``[None, None]``.""" + + finetune: NoFinetuneHypers | FinetuneHypers = { + "read_from": None, + "method": "full", + "config": {}, + "inherit_heads": {}, + } + """Parameters for fine-tuning trained PET models. + + See :ref:`label_fine_tuning_concept` for more details. + """ diff --git a/src/metatrain/pet_8/model.py b/src/metatrain/pet_8/model.py new file mode 100644 index 0000000000..b67afaeee7 --- /dev/null +++ b/src/metatrain/pet_8/model.py @@ -0,0 +1,1646 @@ +import logging +import typing +import warnings +from math import prod +from typing import Any, Dict, List, Literal, Optional, Tuple + +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.operations._add import _add_block_block +from metatomic.torch import ( + AtomisticModel, + ModelCapabilities, + ModelMetadata, + ModelOutput, + NeighborListOptions, + System, +) +from torch.utils.hooks import RemovableHandle + +from metatrain.utils.abc import ModelInterface +from metatrain.utils.additive import ZBL, CompositionModel +from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.dtype import dtype_to_str +from metatrain.utils.long_range import DummyLongRangeFeaturizer, LongRangeFeaturizer +from metatrain.utils.metadata import merge_metadata +from metatrain.utils.scaler import Scaler +from metatrain.utils.sum_over_atoms import sum_over_atoms + +from . 
import checkpoints +from .documentation import ModelHypers +from .modules.finetuning import apply_finetuning_strategy +from .modules.structures import get_pair_sample_labels, systems_to_batch +from .modules.transformer import CartesianTransformer + + +AVAILABLE_FEATURIZERS = typing.get_args(ModelHypers.__annotations__["featurizer_type"]) + + +class PET(ModelInterface[ModelHypers]): + """ + Metatrain-native implementation of the PET architecture. + + Originally proposed in work (https://arxiv.org/abs/2305.19302v3), + and published in the `pet` package (https://github.com/spozdn/pet). + + :param hypers: Hyperparameters for the PET model. See the documentation for details. + :param dataset_info: Information about the dataset, including atomic types and + targets. + """ + + __checkpoint_version__ = 11 + __supported_devices__ = ["cuda", "cpu"] + __supported_dtypes__ = [torch.float32, torch.float64] + __default_metadata__ = ModelMetadata( + references={"architecture": ["https://arxiv.org/abs/2305.19302v3"]} + ) + component_labels: Dict[str, List[List[Labels]]] + NUM_FEATURE_TYPES: int = 2 # node + edge features + + def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None: + super().__init__(hypers, dataset_info, self.__default_metadata__) + + # Cache frequently accessed hyperparameters + self.cutoff = float(self.hypers["cutoff"]) + self.cutoff_function = self.hypers["cutoff_function"] + self.cutoff_width = float(self.hypers["cutoff_width"]) + self.num_neighbors_adaptive = ( + float(self.hypers["num_neighbors_adaptive"]) + if self.hypers["num_neighbors_adaptive"] is not None + else None + ) + self.d_pet = self.hypers["d_pet"] + self.d_node = self.hypers["d_node"] + self.d_head = self.hypers["d_head"] + self.d_feedforward = self.hypers["d_feedforward"] + self.num_heads = self.hypers["num_heads"] + self.num_gnn_layers = self.hypers["num_gnn_layers"] + self.num_attention_layers = self.hypers["num_attention_layers"] + self.normalization = self.hypers["normalization"] + self.activation = self.hypers["activation"] + self.attention_temperature = self.hypers["attention_temperature"] + self.transformer_type = self.hypers["transformer_type"] + self.featurizer_type = self.hypers["featurizer_type"] + + self.atomic_types = dataset_info.atomic_types + self.requested_nl = NeighborListOptions( + cutoff=self.cutoff, + full_list=True, + strict=True, + ) + num_atomic_species = len(self.atomic_types) + self.gnn_layers = torch.nn.ModuleList( + [ + CartesianTransformer( + self.cutoff, + self.cutoff_width, + self.d_pet, + self.num_heads, + self.d_node, + self.d_feedforward, + self.num_attention_layers, + self.normalization, + self.activation, + self.attention_temperature, + self.transformer_type, + num_atomic_species, + layer_index == 0, # is first layer + ) + for layer_index in range(self.num_gnn_layers) + ] + ) + if self.featurizer_type not in AVAILABLE_FEATURIZERS: + raise ValueError( + f"Unknown featurizer type: {self.featurizer_type}. 
" + f"Available options are: {AVAILABLE_FEATURIZERS}" + ) + if self.featurizer_type == "feedforward": + self.num_readout_layers = 1 + self.combination_norms = torch.nn.ModuleList( + [torch.nn.LayerNorm(2 * self.d_pet) for _ in range(self.num_gnn_layers)] + ) + self.combination_mlps = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(2 * self.d_pet, 2 * self.d_pet), + torch.nn.SiLU(), + torch.nn.Linear(2 * self.d_pet, self.d_pet), + ) + for _ in range(self.num_gnn_layers) + ] + ) + else: + self.num_readout_layers = self.num_gnn_layers + self.combination_norms = torch.nn.ModuleList() + self.combination_mlps = torch.nn.ModuleList() + + self.node_embedders = torch.nn.ModuleList( + [ + torch.nn.Embedding(num_atomic_species, self.d_node) + for _ in range(self.num_readout_layers) + ] + ) + self.edge_embedder = torch.nn.Embedding(num_atomic_species, self.d_pet) + + self.node_heads = torch.nn.ModuleDict() + self.edge_heads = torch.nn.ModuleDict() + self.node_last_layers = torch.nn.ModuleDict() + self.edge_last_layers = torch.nn.ModuleDict() + self.last_layer_feature_size = ( + self.num_readout_layers * self.d_head * self.NUM_FEATURE_TYPES + ) # for LLPR + + # the model is always capable of outputting the internal features + self.outputs = { + "features": ModelOutput(per_atom=True, description="internal features") + } + + self.output_shapes: Dict[str, Dict[str, List[int]]] = {} + self.key_labels: Dict[str, Labels] = {} + self.property_labels: Dict[str, List[Labels]] = {} + self.component_labels: Dict[str, List[List[Labels]]] = {} + self.target_names: List[str] = [] + self.last_layer_parameter_names: Dict[str, List[str]] = {} # for LLPR + for target_name, target_info in dataset_info.targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target_info) + + self.register_buffer( + "species_to_species_index", + torch.full((max(self.atomic_types) + 1,), -1), + ) + for i, species in enumerate(self.atomic_types): + self.species_to_species_index[species] = i + + # long-range module + if self.hypers["long_range"]["enable"]: + self.long_range = True + if not self.hypers["long_range"]["use_ewald"]: + warnings.warn( + "Training PET with the LongRangeFeaturizer initialized " + "with `use_ewald=False` causes instabilities during training. " + "The `use_ewald` variable will be force-switched to `True`. 
" + "during training.", + UserWarning, + stacklevel=2, + ) + self.long_range_featurizer = LongRangeFeaturizer( + hypers=self.hypers["long_range"], + feature_dim=self.d_node, + neighbor_list_options=self.requested_nl, + ) + else: + self.long_range = False + self.long_range_featurizer = DummyLongRangeFeaturizer() # for torchscript + + # additive models: these are handled by the trainer at training + # time, and they are added to the output at evaluation time + composition_model = CompositionModel( + hypers={}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + additive_models = [composition_model] + + # Adds the ZBL repulsion model if requested + if self.hypers["zbl"]: + additive_models.append( + ZBL( + {}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if ZBL.is_valid_target(target_name, target_info) + }, + ), + ) + ) + self.additive_models = torch.nn.ModuleList(additive_models) + + # scaler: this is also handled by the trainer at training time + self.scaler = Scaler(hypers={}, dataset_info=dataset_info) + + self.single_label = Labels.single() + + self.finetune_config: Dict[str, Any] = {} + + def supported_outputs(self) -> Dict[str, ModelOutput]: + return self.outputs + + def restart(self, dataset_info: DatasetInfo) -> "PET": + # merge old and new dataset info + merged_info = self.dataset_info.union(dataset_info) + new_atomic_types = [ + at for at in merged_info.atomic_types if at not in self.atomic_types + ] + new_targets = { + key: value + for key, value in merged_info.targets.items() + if key not in self.dataset_info.targets + } + self.has_new_targets = len(new_targets) > 0 + + if len(new_atomic_types) > 0: + raise ValueError( + f"New atomic types found in the dataset: {new_atomic_types}. " + "The PET model does not support adding new atomic types." + ) + + # register new outputs as new last layers + for target_name, target in new_targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target) + + self.dataset_info = merged_info + + # restart the composition and scaler models + self.additive_models[0] = self.additive_models[0].restart( + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + self.scaler = self.scaler.restart(dataset_info) + + return self + + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return [self.requested_nl] + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + """ + Forward pass of the PET model. + + The forward pass processes atomic systems through multiple stages to produce + predictions for the requested outputs. The computation follows a graph neural + network architecture with attention-based message passing. 
+ + **Stage 0: Input Preparation** + + The input systems are first converted into a batched representation containing: + + - `element_indices_nodes` [n_atoms]: Atomic species of the central atoms + - `element_indices_neighbors` [n_atoms, max_num_neighbors]: Atomic species of + neighboring atoms + - `edge_vectors` [n_atoms, max_num_neighbors, 3]: Cartesian edge vectors + between central atoms and their neighbors + - `padding_mask` [n_atoms, max_num_neighbors]: Mask indicating real vs padded + neighbors + - `reverse_neighbor_index` [n_atoms * max_num_neighbors]: Index of the ji edge + for each ij edge, once the edges are flattened into an array whose first + dimension is n_atoms * max_num_neighbors + - `system_indices` [n_atoms]: System index for each central atom + - `sample_labels` [n_atoms, 2]: Metatensor Labels containing indices of each + atom in each system + + **Stage 1: Feature Computation via GNN Layers** + + Node and edge representations are computed by iterating through the GNN layers + following one of two featurization strategies: + + - **Feedforward featurization**: Features are propagated through all + `num_gnn_layers` GNN layers sequentially, using only the final layer outputs + for readout. At each layer, forward and reversed edge messages are combined + using combination MLPs to enable bidirectional information flow. + + - **Residual featurization**: Intermediate node and edge features from each + GNN layer are saved and used during readout. Edge messages between layers are + averaged to maintain information from all hops. + + During this stage, the model: + + - Embeds atomic species into learned node and edge representations + - Applies Cartesian transformer layers to update features via attention + - Uses reversed neighbor lists to enable bidirectional message passing, where + the new input message from atom `j` to atom `i` in GNN layer N+1 is the + reversed message from atom `i` to atom `j` in GNN layer N + - Applies cutoff functions to weight interactions by distance + + If the long-range module is enabled, electrostatic features computed via Ewald + summation (during training) or Particle-Particle Particle Mesh Ewald (P3M) + (during evaluation) are added to the node features from each GNN layer. + + **Stage 2: Intermediate Feature Output (Optional)** + + If "features" is requested in the outputs, node and edge features from all + layers are concatenated to produce intermediate representations. Edge features + are summed over neighbors with cutoff weighting to obtain per-node + contributions. This output can be used for transfer learning or analysis. + + **Stage 3: Last Layer Feature Computation** + + For each requested output, output-specific heads (shallow MLPs with two linear + layers and SiLU activations) are applied to both node and edge features from + each GNN layer. This produces last layer features that are specialized for each + prediction target. These features can be optionally returned as auxiliary + outputs (e.g., "mtt::aux::energy_last_layer_features") for analysis or + transfer learning. 
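+
+        For example, the last-layer features of an ``energy`` target could be
+        requested together with the target itself (illustrative sketch)::
+
+            outputs = {
+                "energy": ModelOutput(per_atom=False),
+                "mtt::aux::energy_last_layer_features": ModelOutput(per_atom=True),
+            }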
+ + **Stage 4: Atomic Predictions** + + Final linear layers are applied to the last layer features to produce per-atom + predictions for each requested output: + + - Node and edge last layer features are processed through separate linear + layers for each output block + - Contributions from all GNN layers are summed + - Edge contributions are summed over neighbors with cutoff weighting + - For rank-2 Cartesian tensors (e.g., stress), predictions are symmetrized and + normalized by cell volume + - Multiple tensor blocks per output are handled independently + + **Post-processing (Evaluation Only)** + + During evaluation (not training), the following transformations are applied: + + 1. **Scaling**: Predictions are scaled using learned or configured scale + factors + 2. **Additive contributions**: Composition model and optional ZBL repulsion + contributions are added to the predictions + + :param systems: List of `metatomic.torch.System` objects to process. Each + system should contain atomic positions, species, and cell information, with + neighbor lists computed according to `requested_neighbor_lists()`. + :param outputs: Dictionary of requested outputs in the format + {output_name: ModelOutput(...)}. The model supports: + + - Target properties (energy, forces, stress, etc.) + - "features": intermediate representations from Stage 2 + - Auxiliary last layer features (e.g., + "mtt::aux::energy_last_layer_features") + + :param selected_atoms: Optional `metatensor.torch.Labels` object specifying a + subset of atoms for which to compute outputs. If `None`, all atoms are + included. This is useful for computing properties for specific atomic + environments. + :return: Dictionary of `metatensor.torch.TensorMap` objects containing the + requested outputs. Each TensorMap contains per-atom or per-structure + predictions (depending on the ModelOutput configuration) with appropriate + metatensor metadata (samples, components, properties). 
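+
+        A minimal usage sketch (illustrative; constructing the systems and
+        attaching the neighbor lists from ``requested_neighbor_lists()`` is
+        omitted)::
+
+            outputs = {"energy": ModelOutput(per_atom=False)}
+            predictions = model(systems, outputs)  # Dict[str, TensorMap]
+            energies = predictions["energy"].block().values  # one row per system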
+ """ + device = systems[0].device + return_dict: Dict[str, TensorMap] = {} + nl_options = self.requested_neighbor_lists()[0] + + if self.single_label.values.device != device: + self._move_labels_to_device(device) + + # **Stage 0: Input Preparation** + ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + padding_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + sample_labels, + centers, + nef_to_edges_neighbor, + ) = systems_to_batch( + systems, + nl_options, + self.atomic_types, + self.species_to_species_index, + self.cutoff_function, + self.cutoff_width, + self.num_neighbors_adaptive, + selected_atoms, + ) + + pair_sample_labels = get_pair_sample_labels( + systems, sample_labels, nl_options, device + ) + + # Optional diagnostic token capture: register temporary module hooks + diagnostic_handles = torch.jit.annotate(List[Any], []) + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + diagnostic_handles = self._prepare_diagnostic_handles( + outputs, + return_dict, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + + # the scaled_dot_product_attention function from torch cannot do + # double backward, so we will use manual attention if needed + use_manual_attention = edge_vectors.requires_grad and self.training + + with torch.profiler.record_function("PET::_calculate_features"): + # **Stage 1: Feature Computation via GNN Layers** + featurizer_inputs: Dict[str, torch.Tensor] = dict( + element_indices_nodes=element_indices_nodes, + element_indices_neighbors=element_indices_neighbors, + edge_vectors=edge_vectors, + edge_distances=edge_distances, + reverse_neighbor_index=reverse_neighbor_index, + padding_mask=padding_mask, + cutoff_factors=cutoff_factors, + ) + for featurizer_input_name, tensor in featurizer_inputs.items(): + if "mtt::features::" + featurizer_input_name in outputs: + return_dict["mtt::features::" + featurizer_input_name] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + node_features_list, edge_features_list = self._calculate_features( + featurizer_inputs, + use_manual_attention=use_manual_attention, + ) + + # If the long-range module is activated, we add the long-range features + # on top of the node features + + if self.long_range: + long_range_features = self._calculate_long_range_features( + systems, node_features_list, edge_distances, padding_mask + ) + for i in range(self.num_readout_layers): + node_features_list[i] = ( + node_features_list[i] + long_range_features + ) * 0.5**0.5 + + # **Stage 2: Intermediate Feature Output (Optional)** + with torch.profiler.record_function("PET::_get_output_features"): + if "features" in outputs: + features_dict = self._get_output_features( + node_features_list, + edge_features_list, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + # Since return_dict.update(features_dict) is not Torch-Scriptable, + # we use a simple iteration over the features_dict items. 
+ for k, v in features_dict.items(): + return_dict[k] = v + + # **Stage 3: Last Layer Feature Computation** + with torch.profiler.record_function("PET::_calculate_last_layer_features"): + node_last_layer_features_dict, edge_last_layer_features_dict = ( + self._calculate_last_layer_features( + node_features_list, + edge_features_list, + ) + ) + last_layer_features_dict = self._get_output_last_layer_features( + node_last_layer_features_dict, + edge_last_layer_features_dict, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + + for k, v in last_layer_features_dict.items(): + return_dict[k] = v + + # **Stage 4: Atomic Predictions** + with torch.profiler.record_function("PET::_calculate_atomic_predictions"): + node_atomic_predictions_dict, edge_atomic_predictions_dict = ( + self._calculate_atomic_predictions( + node_last_layer_features_dict, + edge_last_layer_features_dict, + padding_mask, + cutoff_factors, + outputs, + ) + ) + atomic_predictions_dict = self._get_output_atomic_predictions( + systems, + node_atomic_predictions_dict, + edge_atomic_predictions_dict, + edge_vectors, + system_indices, + sample_labels, + outputs, + selected_atoms, + ) + + for k, v in atomic_predictions_dict.items(): + return_dict[k] = v + + # **Post-processing (Evaluation Only)** + + with torch.profiler.record_function("PET::post-processing"): + if not self.training: + # at evaluation, we also introduce the scaler and additive contributions + return_dict = self.scaler( + systems, return_dict, selected_atoms=selected_atoms + ) + for additive_model in self.additive_models: + outputs_for_additive_model: Dict[str, ModelOutput] = {} + for name, output in outputs.items(): + if name in additive_model.outputs: + outputs_for_additive_model[name] = output + additive_contributions = additive_model( + systems, + outputs_for_additive_model, + selected_atoms, + ) + for name in additive_contributions: + # TODO: uncomment this after metatensor.torch.add + # is updated to handle sparse sums + # return_dict[name] = metatensor.torch.add( + # return_dict[name], + # additive_contributions[name].to( + # device=return_dict[name].device, + # dtype=return_dict[name].dtype + # ), + # ) + # TODO: "manual" sparse sum: update to metatensor.torch.add + # after sparse sum is implemented in metatensor.operations + output_blocks: List[TensorBlock] = [] + for k, b in return_dict[name].items(): + if k in additive_contributions[name].keys: + output_blocks.append( + _add_block_block( + b, + additive_contributions[name] + .block(k) + .to(device=b.device, dtype=b.dtype), + ) + ) + else: + output_blocks.append(b) + return_dict[name] = TensorMap( + return_dict[name].keys, output_blocks + ) + + # remove any diagnostic hooks we registered and attach tokens + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + for h in diagnostic_handles: + h.remove() + + return return_dict + + def _create_diagnostic_feature_tensormap( + self, + tensor: torch.Tensor, + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> TensorMap: + assert tensor.shape[0] == sample_labels.values.shape[0], ( + "diagnostic feature tensor must be per-atom or per-pair like in shape." 
+ f" Got tensor.shape = {tensor.shape}, " + ) + + device = tensor.device + + outp = tensor.detach().clone() + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + if outp.shape[1] == 1: # node-like, shape (n_atoms, 1, d) + outp = outp.squeeze(1) + labels = sample_labels + + else: # edge-like, shape (n_atoms, num_neighbors, d) + outp = outp[centers, nef_to_edges_neighbor] + labels = pair_sample_labels + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + return TensorMap( + Labels(["_"], torch.tensor([[0]])).to(device=device), + [ + TensorBlock( + values=outp, + samples=labels.to(device=device), + components=[], + properties=Labels( + ["_"], + torch.arange(outp.shape[1]).reshape(-1, 1), + ).to(device=device), + ) + ], + ) + + @torch.jit.ignore + def _prepare_diagnostic_handles( + self, + outputs: Dict[str, ModelOutput], + return_dict: Dict[str, Any], + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> List[Any]: + """ + Prepare forward hooks to capture diagnostic tokens from internal modules. + + :param outputs: Dictionary of requested outputs. + :param return_dict: Dictionary to store captured tokens. + :param centers: Tensor mapping center atoms to their indices. + :param nef_to_edges_neighbor: Tensor mapping neighbor edges to their indices. + :param sample_labels: Labels for individual atoms. + :param pair_sample_labels: Labels for atom pairs. + + :return: List of removable handles for the registered hooks. + """ + diagnostic_handles = [] + + def _resolve_module(path: str) -> Any: + obj: Any = self + for part in path.split("."): + if part.isdigit(): + obj = obj[int(part)] + else: + if not hasattr(obj, part): + raise AttributeError( + f"Module path '{path}' not found at '{part}'" + ) + obj = getattr(obj, part) + return obj + + # Build list of possible module paths that can be captured for these model + # hypers. 
+ possible_capture_paths: List[str] = [] + + for i in range(self.num_readout_layers): + possible_capture_paths.append(f"node_embedders.{i}") + possible_capture_paths.append("edge_embedder") + + for i in range(self.num_gnn_layers): + # Total GNN layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}_node") + possible_capture_paths.append(f"gnn_layers.{i}_edge") + + # Finer-grained: embeddings and compressions + possible_capture_paths.append(f"gnn_layers.{i}.edge_embedder") + if i > 0: + possible_capture_paths.append(f"gnn_layers.{i}.neighbor_embedder") + possible_capture_paths.append(f"gnn_layers.{i}.compress") + + # transformer layers + for j in range(self.num_attention_layers): + # Transformer layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_node") + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_edge") + + # Finer-grained: attention and MLP submodules + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_contraction" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_mlp" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_mlp" + ) + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}.mlp") + + # TODO: this is subtle, depending on whether the featurizer is feedforward or + # residual. To fix later. + # possible_capture_paths.append(f"combination_norms") + # possible_capture_paths.append(f"combination_mlps") + + for path in possible_capture_paths: + if "mtt::features::" + path not in outputs: + continue + + if "_node" in path: + suffix = "_node" + path = path.replace(suffix, "") + elif "_edge" in path: + suffix = "_edge" + path = path.replace(suffix, "") + else: + suffix = "" + + module = _resolve_module(path) + + def make_hook(p: str, suffix: str) -> Any: + def _hook( + module: torch.nn.Module, inp: Any, outp: Any + ) -> None: + if isinstance(outp, tuple): + assert "_node" in suffix or "_edge" in suffix, ( + "When capturing from a module that returns multiple " + "outputs, the requested output must carry the suffix " + "'_node' or '_edge'." + ) + if suffix == "_node": + tensor = outp[0] + else: + assert suffix == "_edge" + tensor = outp[1] + + return_dict[f"mtt::features::{p}{suffix}"] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + else: + return_dict["mtt::features::" + p] = ( + self._create_diagnostic_feature_tensormap( + outp, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + + return _hook + + handle = module.register_forward_hook(make_hook(path, suffix)) + diagnostic_handles.append(handle) + + return diagnostic_handles + + def _calculate_features( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Calculate node and edge features using the selected featurization strategy. + Returns lists of feature tensors from GNN layers. 
+ + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors + - List of edge feature tensors + In the case of feedforward featurization, each list contains a single tensor + from the final GNN layer. In the case of residual featurization, each list + contains tensors from all GNN layers. + """ + if self.featurizer_type == "feedforward": + return self._feedforward_featurization_impl(inputs, use_manual_attention) + else: + return self._residual_featurization_impl(inputs, use_manual_attention) + + def _feedforward_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Feedforward featurization: iterates features through all GNN layers, + returning only the final layer outputs. Uses combination MLPs to mix + forward and reversed edge messages at each layer. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from the final GNN layer + - List of edge feature tensors from the final GNN layer + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + + input_node_embeddings = self.node_embedders[0](inputs["element_indices_nodes"]) + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for combination_norm, combination_mlp, gnn_layer in zip( + self.combination_norms, self.combination_mlps, self.gnn_layers, strict=True + ): + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + input_node_embeddings = output_node_embeddings + new_input_edge_embeddings = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + # input_messages = 0.5 * (output_edge_embeddings + new_input_messages) + concatenated = torch.cat( + [output_edge_embeddings, new_input_edge_embeddings], dim=-1 + ) + input_edge_embeddings = ( + input_edge_embeddings + + output_edge_embeddings + + combination_mlp(combination_norm(concatenated)) + ) + + node_features_list.append(input_node_embeddings) + edge_features_list.append(input_edge_embeddings) + return node_features_list, edge_features_list + + def _residual_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Residual featurization: saves intermediate features from each GNN layer + for use in readout. 
Averages forward and reversed edge messages between layers. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from all GNN layers + - List of edge feature tensors from all GNN layers + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for node_embedder, gnn_layer in zip( + self.node_embedders, self.gnn_layers, strict=True + ): + input_node_embeddings = node_embedder(inputs["element_indices_nodes"]) + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + node_features_list.append(output_node_embeddings) + edge_features_list.append(output_edge_embeddings) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + # (Flatten, index, and reshape to the original shape) + new_input_messages = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + input_edge_embeddings = 0.5 * (input_edge_embeddings + new_input_messages) + return node_features_list, edge_features_list + + def _calculate_long_range_features( + self, + systems: List[System], + node_features_list: List[torch.Tensor], + edge_distances: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + """ + Calculate long-range electrostatic features using Ewald summation. + Forces use_ewald=True during training for stability. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_distances: Tensor of edge distances [n_atoms, max_num_neighbors]. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :return: Tensor of long-range features [n_atoms, d_pet]. + """ + if self.training: + # Currently, the long-range implementation show instabilities + # during training if P3MCalculator is used instead of the + # EwaldCalculator. We will use the EwaldCalculator for training. 
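+            # At evaluation time, the calculator configured in the long-range
+            # hypers (e.g. P3M) is used instead, as described in the forward()
+            # docstring.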
+ self.long_range_featurizer.use_ewald = True + flattened_lengths = edge_distances[padding_mask] + short_range_features = ( + torch.stack(node_features_list).sum(dim=0) + * (1 / len(node_features_list)) ** 0.5 + ) + long_range_features = self.long_range_featurizer( + systems, short_range_features, flattened_lengths + ) + return long_range_features + + def _get_output_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Concatenate node and edge features from all layers into intermediate + feature representations. Edge features are summed with cutoff weighting. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping "features" to a TensorMap of intermediate + representations, either per-atom or summed over atoms. + """ + features_dict: Dict[str, TensorMap] = {} + node_features = torch.cat(node_features_list, dim=1) + edge_features = torch.cat(edge_features_list, dim=2) + edge_features = (edge_features * cutoff_factors[:, :, None]).sum(dim=1) + features = torch.cat([node_features, edge_features], dim=1) + + feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=features, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + features.shape[-1], device=features.device + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + feature_tmap = mts.slice( + feature_tmap, + axis="samples", + selection=selected_atoms, + ) + if requested_outputs["features"].per_atom: + features_dict["features"] = feature_tmap + else: + features_dict["features"] = sum_over_atoms(feature_tmap) + return features_dict + + def _calculate_last_layer_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + ) -> Tuple[Dict[str, List[torch.Tensor]], Dict[str, List[torch.Tensor]]]: + """ + Apply output-specific heads to node and edge features from each GNN layer. + Returns dictionaries mapping output names to lists of head-transformed features. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of node last layer features + - Dictionary mapping output names to lists of edge last layer features + """ + node_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + + # Calculating node last layer features + for output_name, node_heads in self.node_heads.items(): + if output_name not in node_last_layer_features_dict: + node_last_layer_features_dict[output_name] = [] + for i, node_head in enumerate(node_heads): + node_last_layer_features_dict[output_name].append( + node_head(node_features_list[i]) + ) + + # Calculating edge last layer features + for output_name, edge_heads in self.edge_heads.items(): + if output_name not in edge_last_layer_features_dict: + edge_last_layer_features_dict[output_name] = [] + for i, edge_head in enumerate(edge_heads): + edge_last_layer_features_dict[output_name].append( + edge_head(edge_features_list[i]) + ) + + return node_last_layer_features_dict, edge_last_layer_features_dict + + def _get_output_last_layer_features( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge last layer features for requested last layer + features output. Edge features are summed with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping requested last layer features output names + to TensorMaps of last layer features, either per-atom or summed over atoms. 
+ """ + last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + last_layer_features_outputs: Dict[str, TensorMap] = {} + for output_name in node_last_layer_features_dict.keys(): + if not should_compute_last_layer_features(output_name, requested_outputs): + continue + if output_name not in last_layer_features_dict: + last_layer_features_dict[output_name] = [] + for i in range(len(node_last_layer_features_dict[output_name])): + node_last_layer_features = node_last_layer_features_dict[output_name][i] + edge_last_layer_features = edge_last_layer_features_dict[output_name][i] + edge_last_layer_features = ( + edge_last_layer_features * cutoff_factors[:, :, None] + ).sum(dim=1) + last_layer_features_dict[output_name].append(node_last_layer_features) + last_layer_features_dict[output_name].append(edge_last_layer_features) + + for output_name in requested_outputs: + if not ( + output_name.startswith("mtt::aux::") + and output_name.endswith("_last_layer_features") + ): + continue + base_name = output_name.replace("mtt::aux::", "").replace( + "_last_layer_features", "" + ) + # the corresponding output could be base_name or mtt::base_name + if f"mtt::{base_name}" in last_layer_features_dict: + base_name = f"mtt::{base_name}" + last_layer_features_values = torch.cat( + last_layer_features_dict[base_name], dim=1 + ) + last_layer_feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=last_layer_features_values, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + last_layer_features_values.shape[-1], + device=last_layer_features_values.device, + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + last_layer_feature_tmap = mts.slice( + last_layer_feature_tmap, + axis="samples", + selection=selected_atoms, + ) + last_layer_features_options = requested_outputs[output_name] + if last_layer_features_options.per_atom: + last_layer_features_outputs[output_name] = last_layer_feature_tmap + else: + last_layer_features_outputs[output_name] = sum_over_atoms( + last_layer_feature_tmap + ) + return last_layer_features_outputs + + def _calculate_atomic_predictions( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + padding_mask: torch.Tensor, + cutoff_factors: torch.Tensor, + outputs: Dict[str, ModelOutput], + ) -> Tuple[ + Dict[str, List[List[torch.Tensor]]], Dict[str, List[List[torch.Tensor]]] + ]: + """ + Apply final linear layers to last layer features to produce + per-atom predictions. Handles multiple blocks per output and sums + edge contributions with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param outputs: Dictionary of requested outputs. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of lists of node atomic + prediction tensors (one list per GNN layer, one tensor per block) + - Dictionary mapping output names to lists of lists of edge atomic + prediction tensors (one list per GNN layer, one tensor per block) + """ + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + + # Computing node atomic predictions. Since we have last layer features + # for each GNN layer, and each last layer can have multiple blocks, + # we apply each last layer block to each of the last layer features. + + for output_name, node_last_layers in self.node_last_layers.items(): + if output_name in outputs: + node_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, node_last_layer in enumerate(node_last_layers): + node_last_layer_features = node_last_layer_features_dict[ + output_name + ][i] + node_atomic_predictions_by_block: List[torch.Tensor] = [] + for node_last_layer_by_block in node_last_layer.values(): + node_atomic_predictions_by_block.append( + node_last_layer_by_block(node_last_layer_features) + ) + node_atomic_predictions_dict[output_name].append( + node_atomic_predictions_by_block + ) + + # Computing edge atomic predictions. Following the same logic as above, + # we (1) iterate over the last layer features and last layer blocks, and (2) + # sum the edge features with cutoff factors to get their per-node contribution. + + for output_name, edge_last_layers in self.edge_last_layers.items(): + if output_name in outputs: + edge_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, edge_last_layer in enumerate(edge_last_layers): + edge_last_layer_features = edge_last_layer_features_dict[ + output_name + ][i] + edge_atomic_predictions_by_block: List[torch.Tensor] = [] + for edge_last_layer_by_block in edge_last_layer.values(): + edge_atomic_predictions = edge_last_layer_by_block( + edge_last_layer_features + ) + expanded_padding_mask = padding_mask[..., None].repeat( + 1, 1, edge_atomic_predictions.shape[2] + ) + edge_atomic_predictions = torch.where( + ~expanded_padding_mask, 0.0, edge_atomic_predictions + ) + edge_atomic_predictions_by_block.append( + (edge_atomic_predictions * cutoff_factors[:, :, None]).sum( + dim=1 + ) + ) + edge_atomic_predictions_dict[output_name].append( + edge_atomic_predictions_by_block + ) + + return node_atomic_predictions_dict, edge_atomic_predictions_dict + + def _get_output_atomic_predictions( + self, + systems: List[System], + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_vectors: torch.Tensor, + system_indices: torch.Tensor, + sample_labels: Labels, + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge atomic predictions into final TensorMaps. + Handles rank-2 Cartesian tensors by symmetrizing them. + Returns per-atom or per-structure predictions based on output configuration. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of node atomic prediction tensors (one list per GNN layer, + one tensor per block). 
+ :param edge_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of edge atomic prediction tensors (one list per GNN layer, + one tensor per block). + :param edge_vectors: Tensor of edge vectors [n_atoms, max_num_neighbors, 3]. + :param system_indices: Tensor mapping each atom to its system index + [n_atoms]. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param outputs: Dictionary of requested outputs. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :return: Dictionary mapping requested output names to TensorMaps of + predictions, either per-atom or summed over atoms. + """ + atomic_predictions_tmap_dict: Dict[str, TensorMap] = {} + for output_name in self.target_names: + if output_name in outputs: + atomic_predictions_by_block = { + key: torch.zeros( + 1, dtype=edge_vectors.dtype, device=edge_vectors.device + ) + for key in self.output_shapes[output_name].keys() + } + + node_atomic_predictions_by_block = node_atomic_predictions_dict[ + output_name + ] + edge_atomic_predictions_by_block = edge_atomic_predictions_dict[ + output_name + ] + for i in range(len(node_atomic_predictions_by_block)): + node_atomic_prediction_block = node_atomic_predictions_by_block[i] + edge_atomic_prediction_block = edge_atomic_predictions_by_block[i] + for j, key in enumerate(atomic_predictions_by_block): + node_atomic_predictions = node_atomic_prediction_block[j] + edge_atomic_predictions = edge_atomic_prediction_block[j] + atomic_predictions_by_block[key] = atomic_predictions_by_block[ + key + ] + (node_atomic_predictions + edge_atomic_predictions) + + if output_name == "non_conservative_stress": # TODO: variants + block_key = list(atomic_predictions_by_block.keys())[0] + output_shapes_values = list( + self.output_shapes[output_name].values() + ) + num_properties = output_shapes_values[0][-1] + symmetrized = process_non_conservative_stress( + atomic_predictions_by_block[block_key], + systems, + system_indices, + num_properties, + ) + atomic_predictions_by_block[block_key] = symmetrized + + blocks = [ + TensorBlock( + values=atomic_predictions_by_block[key].reshape([-1] + shape), + samples=sample_labels, + components=components, + properties=properties, + ) + for key, shape, components, properties in zip( + self.output_shapes[output_name].keys(), + self.output_shapes[output_name].values(), + self.component_labels[output_name], + self.property_labels[output_name], + strict=True, + ) + ] + atomic_predictions_tmap_dict[output_name] = TensorMap( + keys=self.key_labels[output_name], + blocks=blocks, + ) + # If selected atoms request is provided, we slice the atomic predictions + # tensor maps to get the predictions for the selected atoms only. + + if selected_atoms is not None: + for output_name, tmap in atomic_predictions_tmap_dict.items(): + atomic_predictions_tmap_dict[output_name] = mts.slice( + tmap, axis="samples", selection=selected_atoms + ) + + # If per-atom predictions are requested, we return the atomic predictions + # tensor maps. Otherwise, we sum the atomic predictions over the atoms + # to get the final per-structure predictions for each requested output. 
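+        # Illustrative example of the reduction below (names and shapes are only
+        # for illustration): for a scalar target such as an energy, per_atom=True
+        # keeps one sample per ("system", "atom") pair, while per_atom=False uses
+        # sum_over_atoms to reduce the block to one sample per "system".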
+ + for output_name, atomic_property in atomic_predictions_tmap_dict.items(): + if outputs[output_name].per_atom: + atomic_predictions_tmap_dict[output_name] = atomic_property + else: + atomic_predictions_tmap_dict[output_name] = sum_over_atoms( + atomic_property + ) + + return atomic_predictions_tmap_dict + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + context: Literal["restart", "finetune", "export"], + ) -> "PET": + if context == "restart": + logging.info(f"Using latest model from epoch {checkpoint['epoch']}") + model_state_dict = checkpoint["model_state_dict"] + elif context in {"finetune", "export"}: + logging.info(f"Using best model from epoch {checkpoint['best_epoch']}") + model_state_dict = checkpoint["best_model_state_dict"] + else: + raise ValueError("Unknown context tag for checkpoint loading!") + + # Create the model + model_data = checkpoint["model_data"] + model = cls( + hypers=model_data["model_hypers"], + dataset_info=model_data["dataset_info"], + ) + + finetune_config = model_state_dict.pop("finetune_config", {}) + if finetune_config: + # Apply the finetuning strategy + model = apply_finetuning_strategy(model, finetune_config) + state_dict_iter = iter(model_state_dict.values()) + next(state_dict_iter) # skip the species_to_species_index + dtype = next(state_dict_iter).dtype + model.to(dtype).load_state_dict(model_state_dict) + model.additive_models[0].sync_tensor_maps() + model.scaler.sync_tensor_maps() + + # Loading the metadata from the checkpoint + model.metadata = merge_metadata(model.metadata, checkpoint.get("metadata")) + + return model + + def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: + dtype = next(self.parameters()).dtype + if dtype not in self.__supported_dtypes__: + raise ValueError(f"unsupported dtype {dtype} for PET") + + # Make sure the model is all in the same dtype + # For example, after training, the additive models could still be in + # float64 + self.to(dtype) + + # Additionally, the composition model contains some `TensorMap`s that cannot + # be registered correctly with Pytorch. This function moves them: + self.additive_models[0].weights_to(torch.device("cpu"), torch.float64) + + interaction_ranges = [self.num_gnn_layers * self.cutoff] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + + capabilities = ModelCapabilities( + outputs=self.outputs, + atomic_types=self.atomic_types, + interaction_range=interaction_range, + length_unit=self.dataset_info.length_unit, + supported_devices=self.__supported_devices__, + dtype=dtype_to_str(dtype), + ) + + metadata = merge_metadata(self.metadata, metadata) + + return AtomisticModel(self.eval(), metadata, capabilities) + + def _add_output(self, target_name: str, target_info: TargetInfo) -> None: + """ + Register a new output target by creating corresponding heads and last layers. + Sets up node/edge heads and linear layers for all readout layers. + + :param target_name: Name of the target to add. + :param target_info: TargetInfo object containing details about the target. + """ + # one output shape for each tensor block, grouped by target (i.e. 
tensormap) + self.output_shapes[target_name] = {} + for key, block in target_info.layout.items(): + dict_key = target_name + for n, k in zip(key.names, key.values, strict=True): + dict_key += f"_{n}_{int(k)}" + self.output_shapes[target_name][dict_key] = [ + len(comp.values) for comp in block.components + ] + [len(block.properties.values)] + + self.outputs[target_name] = ModelOutput( + quantity=target_info.quantity, + unit=target_info.unit, + per_atom=True, + description=target_info.description, + ) + + self.node_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_node, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_pet, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.node_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: torch.nn.Linear( + self.d_head, + prod(shape), + bias=True, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: torch.nn.Linear( + self.d_head, + prod(shape), + bias=True, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + # Register last-layer parameters, in the same order as they are returned as + # last-layer features in the model + self.last_layer_parameter_names[target_name] = [] + for layer_index in range(self.num_readout_layers): + for key in self.output_shapes[target_name].keys(): + self.last_layer_parameter_names[target_name].append( + f"node_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + self.last_layer_parameter_names[target_name].append( + f"edge_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + + ll_features_name = get_last_layer_features_name(target_name) + self.outputs[ll_features_name] = ModelOutput( + per_atom=True, description=f"last layer features for {target_name}" + ) + self.key_labels[target_name] = target_info.layout.keys + self.component_labels[target_name] = [ + block.components for block in target_info.layout.blocks() + ] + self.property_labels[target_name] = [ + block.properties for block in target_info.layout.blocks() + ] + + def _move_labels_to_device(self, device: torch.device) -> None: + self.single_label = self.single_label.to(device) + self.key_labels = { + output_name: label.to(device) + for output_name, label in self.key_labels.items() + } + self.component_labels = { + output_name: [ + [labels.to(device) for labels in components_block] + for components_block in components_tmap + ] + for output_name, components_tmap in self.component_labels.items() + } + self.property_labels = { + output_name: [labels.to(device) for labels in properties_tmap] + for output_name, properties_tmap in self.property_labels.items() + } + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["model_ckpt_version"] == v: + update = getattr(checkpoints, f"model_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["model_ckpt_version"] = v + 1 + + if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__: + raise 
RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using model " + f"version {checkpoint['model_ckpt_version']}, while the current model " + f"version is {cls.__checkpoint_version__}." + ) + + return checkpoint + + def get_checkpoint(self) -> Dict: + model_state_dict = self.state_dict() + model_state_dict["finetune_config"] = self.finetune_config + checkpoint = { + "architecture_name": "pet_8", + "model_ckpt_version": self.__checkpoint_version__, + "metadata": self.metadata, + "model_data": { + "model_hypers": self.hypers, + "dataset_info": self.dataset_info, + }, + "epoch": None, + "best_epoch": None, + "model_state_dict": model_state_dict, + "best_model_state_dict": self.state_dict(), + } + return checkpoint + + +def process_non_conservative_stress( + tensor: torch.Tensor, + systems: List[System], + system_indices: torch.Tensor, + num_properties: int, +) -> torch.Tensor: + """ + Symmetrizes and normalizes by the volume rank-2 Cartesian tensors that are meant + to predict the non-conservative stress. + + :param tensor: Tensor of shape [n_atoms, 9 * num_properties]. + :param systems: List of `metatomic.torch.System` objects to process. + :param system_indices: Tensor mapping each atom to its system index [n_atoms]. + :param num_properties: Number of properties in the tensor (e.g., 6 for stress). + :return: Symmetrized tensor of shape [n_atoms, 3, 3, num_properties], divided by the + cell volume. + """ + # Reshape to 3x3 matrix per atom + tensor_as_three_by_three = tensor.reshape(-1, 3, 3, num_properties) + + # Normalize by cell volume + volumes = torch.stack([torch.abs(torch.det(system.cell)) for system in systems]) + # Zero volume can happen due to metatomic's convention of zero cell + # vectors for non-periodic directions. The actual volume is +inf + volumes[volumes == 0.0] = torch.inf + volumes_by_atom = volumes[system_indices].unsqueeze(1).unsqueeze(2).unsqueeze(3) + tensor_as_three_by_three = tensor_as_three_by_three / volumes_by_atom + + # Symmetrize + tensor_as_three_by_three = ( + tensor_as_three_by_three + tensor_as_three_by_three.transpose(1, 2) + ) / 2.0 + + return tensor_as_three_by_three + + +def get_last_layer_features_name(target_name: str) -> str: + """ + Get the auxiliary output name for last layer features of a target. + + :param target_name: Name of the target. + :return: Name of the corresponding last layer features output. + """ + base_name = target_name.replace("mtt::", "") + return f"mtt::aux::{base_name}_last_layer_features" + + +def should_compute_last_layer_features( + output_name: str, requested_outputs: Dict[str, ModelOutput] +) -> bool: + """ + Check if last layer features should be computed for an output. + + :param output_name: Name of the output to check. + :param requested_outputs: Dictionary of requested outputs. + :return: True if last layer features should be computed, False otherwise. + """ + if output_name in requested_outputs: + return True + ll_features_name = get_last_layer_features_name( + output_name.replace("mtt::aux::", "") + ) + return ll_features_name in requested_outputs diff --git a/src/metatrain/pet_8/modules/adaptive_cutoff.py b/src/metatrain/pet_8/modules/adaptive_cutoff.py new file mode 100644 index 0000000000..5f6206bdb1 --- /dev/null +++ b/src/metatrain/pet_8/modules/adaptive_cutoff.py @@ -0,0 +1,166 @@ +from typing import Optional + +import torch + +from .utilities import cutoff_func_bump as cutoff_func + + +# minimum value for the probe cutoff. this avoids getting too close +# to the central atom. 
in practice it could be also set to a larger value +DEFAULT_MIN_PROBE_CUTOFF = 0.5 +# recommended smooth cutoff width for effective neighbor number calculation +# smaller values lead to a more "step-like" behavior, but can be +# numerically unstable. in practice this will be called with the +# same cutoff as the main cutoff function +DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH = 1.0 + + +def get_adaptive_cutoffs( + centers: torch.Tensor, + edge_distances: torch.Tensor, + num_neighbors_adaptive: float, + num_nodes: int, + max_cutoff: float, + min_cutoff: float = DEFAULT_MIN_PROBE_CUTOFF, + cutoff_width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, + probe_spacing: Optional[float] = None, + weight_width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the adaptive cutoff values for each center atom. + + :param centers: Indices of the center atoms. + :param edge_distances: Distances between centers and their neighbors. + :param num_neighbors_adaptive: Target number of neighbors per atom. + :param num_nodes: Total number of center atoms. + :param max_cutoff: Maximum cutoff distance to consider. + :param min_cutoff: Minimum cutoff distance to consider. + :param cutoff_width: Width of the smooth cutoff taper region. + :param probe_spacing: Spacing between probe cutoffs. If None, it will be + automatically determined from the cutoff width. + :param weight_width: Width of the cutoff selection weight function. If None, it + will be automatically determined from the empirical neighbor counts. + :return: Adapted cutoff distances for each center atom. + """ + + # heuristic for the grid spacing of probe cutoffs, based on a + # the smoothness of the cutoff function + if probe_spacing is None: + probe_spacing = cutoff_width / 4.0 + probe_cutoffs = torch.arange( + min_cutoff, + max_cutoff, + probe_spacing, + device=edge_distances.device, + dtype=edge_distances.dtype, + ) + with torch.profiler.record_function("PET::get_effective_num_neighbors"): + effective_num_neighbors = get_effective_num_neighbors( + edge_distances, + probe_cutoffs, + centers, + num_nodes, + width=cutoff_width, + ) + + with torch.profiler.record_function("PET::get_cutoff_weights"): + cutoffs_weights = get_gaussian_cutoff_weights( + effective_num_neighbors, + num_neighbors_adaptive, + width=weight_width, + ) + with torch.profiler.record_function("PET::calculate_adapted_cutoffs"): + adapted_atomic_cutoffs = probe_cutoffs @ cutoffs_weights.T + return adapted_atomic_cutoffs + + +def get_effective_num_neighbors( + edge_distances: torch.Tensor, + probe_cutoffs: torch.Tensor, + centers: torch.Tensor, + num_nodes: int, + width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, +) -> torch.Tensor: + """ + Computes the effective number of neighbors for each probe cutoff. + + :param edge_distances: Distances between centers and their neighbors. + :param probe_cutoffs: Probe cutoff distances. + :param centers: Indices of the center atoms. + :param num_nodes: Total number of center atoms. + :param width: Width of the cutoff function. + :return: Effective number of neighbors for each center atom and probe cutoff. 
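+    Each row corresponds to a center atom and each column to a probe cutoff,
+    i.e. the returned tensor has shape ``(num_nodes, len(probe_cutoffs))``.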
+ """ + + weights = cutoff_func( + edge_distances.unsqueeze(0), probe_cutoffs.unsqueeze(1), width + ) + + probe_num_neighbors = torch.zeros( + (len(probe_cutoffs), num_nodes), + dtype=edge_distances.dtype, + device=edge_distances.device, + ) + # accumulate the weights for all probe cutoffs and center atoms at once + probe_num_neighbors.index_add_(1, centers, weights) + probe_num_neighbors = probe_num_neighbors.T + + return probe_num_neighbors + + +def get_gaussian_cutoff_weights( + effective_num_neighbors: torch.Tensor, + num_neighbors_adaptive: float, + width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the weights for each probe cutoff based on + the effective number of neighbors using Gaussian weights + centered at the expected number of neighbors. + + :param effective_num_neighbors: Effective number of neighbors for each center atom + and probe cutoff. + :param num_neighbors_adaptive: Target maximum number of neighbors per atom. + :param width: Width of the Gaussian cutoff selection function. + :return: Weights for each probe cutoff. + """ + diff = effective_num_neighbors - num_neighbors_adaptive + + # adds a "baseline" corresponding to uniformly-distributed atoms + # this has multiple "good" effects: it pushes the cutoff "out" when + # there are few neighbors, and "in" when there are many, and it + # stabilizes the weights with respect to variations in the neighbor + # distribution when there are empty ranges leading to "flat" + # neighbor count distribution + x = torch.linspace( + 0, + 1, + effective_num_neighbors.shape[1], + device=effective_num_neighbors.device, + dtype=effective_num_neighbors.dtype, + ) + baseline = num_neighbors_adaptive * x**3 + + diff = diff + baseline.unsqueeze(0) + if width is None: + # adaptive width from neighbor-count slope along probe axis (last dim) + eps = 1e-12 + if diff.shape[-1] == 1: + # Can't compute gradient from single point; use scaled diff as proxy + width_t = diff.abs() * 0.5 + eps + else: + # Compute numerical gradient: centered differences for interior, + # one-sided differences at boundaries + (width_t,) = torch.gradient(diff, dim=-1) + width_t = width_t.abs().clamp_min(eps) + else: + width_t = torch.ones_like(diff) * width + + logw = -0.5 * (diff / width_t) ** 2 + weights = torch.exp(logw - logw.max()) + + # row-wise normalization of the weights + weights_sum = weights.sum(dim=1, keepdim=True) + weights = weights / weights_sum + + return weights diff --git a/src/metatrain/pet_8/modules/finetuning.py b/src/metatrain/pet_8/modules/finetuning.py new file mode 100644 index 0000000000..27f2086f19 --- /dev/null +++ b/src/metatrain/pet_8/modules/finetuning.py @@ -0,0 +1,264 @@ +# mypy: disable-error-code=misc +# We ignore misc errors in this file because TypedDict +# with default values is not allowed by mypy. 
+from typing import Any, Optional, Tuple + +import torch +import torch.nn as nn +from typing_extensions import Literal, NotRequired, TypedDict + + +class LoRaFinetuneConfig(TypedDict): + """Configuration for LoRA finetuning strategy.""" + + rank: int + """Rank of the LoRA matrices.""" + alpha: float + """Scaling factor for the LoRA matrices.""" + target_modules: NotRequired[list[str]] + + +class HeadsFinetuneConfig(TypedDict): + """Configuration for heads finetuning strategy.""" + + head_modules: list[str] + """List of module name prefixes for the prediction heads to finetune.""" + last_layer_modules: list[str] + """List of module name prefixes for the last layers to finetune.""" + + +class NoFinetuneHypers(TypedDict): + """Hypers that indicate that no finetuning is to be applied.""" + + read_from: None = None + """No finetuning is indicated by setting this argument to None. + + The rest of finetuning hyperparameters are then ignored. + """ + method: NotRequired[Any] + config: NotRequired[Any] + inherit_heads: NotRequired[Any] + + +class FullFinetuneHypers(TypedDict): + """Hyperparameters to use full finetuning of PET models. + + This means all model parameters are trainable. + """ + + method: Literal["full"] = "full" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: NotRequired[Any] + """No configuration needed for full finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class LoRaFinetuneHypers(TypedDict): + """Hyperparameters for LoRA finetuning of PET models. + + Injects LoRA layers and finetunes only them. + """ + + method: Literal["lora"] = "lora" + """Finetuning method to use""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: LoRaFinetuneConfig + """Configuration for LoRA finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class HeadsFinetuneHypers(TypedDict): + """Hyperparameters for heads finetuning of PET models. + + Freezes all model parameters except for the prediction heads + and last layers. + """ + + method: Literal["heads"] = "heads" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: HeadsFinetuneConfig + """Configuration for heads finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +FinetuneHypers = FullFinetuneHypers | LoRaFinetuneHypers | HeadsFinetuneHypers + + +def apply_finetuning_strategy(model: nn.Module, strategy: FinetuneHypers) -> nn.Module: + """ + Apply the specified finetuning strategy to the model. + This function modifies the model in place based on the provided strategy. + + :param model: The model to be finetuned. + :param strategy: A dictionary specifying the finetuning strategy. + The strategy method can be one of the following: + - lora: Inject LoRA layers into the model, or reapply training if already + present. 
+ - heads: Freeze all parameters except for the heads and last layers. + - full: All parameters are trainable. + Additionally, the strategy can include an "inherit_heads" key, + which is a dictionary mapping the new trainable targets to the existing + targets in the model. This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization. + :return: The modified model with the finetuning strategy applied. + """ + + for param in model.parameters(): + param.requires_grad = True + + if strategy["method"] == "full": + # Full finetuning, all parameters are trainable + pass + + elif strategy["method"] == "lora": + lora_config = strategy["config"] + lora_already_applied = any(isinstance(m, LoRALinear) for m in model.modules()) + if not lora_already_applied: + model_device = next(model.parameters()).device + model_dtype = next(model.parameters()).dtype + model = inject_lora_layers( + model, + target_modules=tuple( + lora_config.get("target_modules", ["input_linear", "output_linear"]) + ), + rank=lora_config.get("rank", 4), + alpha=lora_config.get("alpha", 8), + device=model_device, + dtype=model_dtype, + ) + + # Freeze all except LoRA + for name, param in model.named_parameters(): + if "lora_" not in name: + param.requires_grad = False + + elif strategy["method"] == "heads": + heads_config = strategy.get( + "config", + { + "head_modules": ["node_heads", "edge_heads"], + "last_layer_modules": ["node_last_layers", "edge_last_layers"], + }, + ) + + head_keywords = heads_config.get("head_modules", []) + last_layer_keywords = heads_config.get("last_layer_modules", []) + + for name, param in model.named_parameters(): + if any(name.startswith(kw) for kw in head_keywords + last_layer_keywords): + param.requires_grad = True + else: + param.requires_grad = False + + else: + raise ValueError( + f"Unknown finetuning strategy: {strategy['method']}. Available methods " + "are: 'full', 'lora', 'heads'." + ) + + model.finetune_config = strategy + + inherit_heads_config = strategy["inherit_heads"] + if inherit_heads_config: + for dest_target_name, source_target_name in inherit_heads_config.items(): + model_parameters = dict(model.named_parameters()) + if not any(f".{source_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the source target name '{source_target_name}' was not found in " + "the model. Please specify the correct source target name." + ) + if not any(f".{dest_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the destination target name '{dest_target_name}' was not found " + "in the model. Please specify the correct destination target name." + ) + for name, param in model_parameters.items(): + if f".{source_target_name}." in name: + corresponding_dest_name = name.replace( + source_target_name, dest_target_name + ) + if corresponding_dest_name in model_parameters: + model_parameters[corresponding_dest_name].data.copy_(param.data) + else: + raise ValueError( + f"Destination head '{dest_target_name}' not found in model." + ) + return model + + +def inject_lora_layers( + model: nn.Module, + target_modules: Tuple[str, ...] = ("input_linear", "output_linear"), + rank: int = 4, + alpha: float = 1.0, + device: Optional[str] = None, + dtype: Optional[torch.dtype] = None, +) -> nn.Module: + """ + Inject LoRA layers into the model. 
+ This function replaces the specified linear layers in the model with + LoRALinear layers. + + :param model: The model to modify. + :param target_modules: A tuple of strings specifying the names of the attributes of + the modules to be replaced with LoRA layers. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + :param device: The device to which the LoRA layers should be moved. If None, the + LoRA layers will be on the same device as the original model. + :param dtype: The data type to which the LoRA layers should be converted. If + None, the LoRA layers will have the same dtype as the original model. + :return: The modified model with LoRA layers injected. + """ + for _, module in model.named_modules(): + for attr in target_modules: + if hasattr(module, attr): + linear = getattr(module, attr) + if isinstance(linear, nn.Linear): + lora_linear = LoRALinear(linear, rank=rank, alpha=alpha) + lora_linear = lora_linear.to(dtype=dtype, device=device) + setattr(module, attr, lora_linear) + return model + + +class LoRALinear(nn.Module): + """ + LoRA Linear layer. + This is a wrapper around nn.Linear that adds LoRA functionality. + LoRA is a technique for low-rank adaptation of large language models. + It allows for efficient fine-tuning of large models by injecting low-rank + matrices into the model's weights. + + :param linear_layer: The original linear layer to be wrapped. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + """ + + def __init__(self, linear_layer: nn.Module, rank: int = 4, alpha: float = 1.0): + super().__init__() + self.linear = linear_layer + self.lora_A = nn.Linear(linear_layer.in_features, rank, bias=False) + self.lora_B = nn.Linear(rank, linear_layer.out_features, bias=False) + self.scaling = alpha / rank + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + self.scaling * self.lora_B(self.lora_A(x)) diff --git a/src/metatrain/pet_8/modules/nef.py b/src/metatrain/pet_8/modules/nef.py new file mode 100644 index 0000000000..20ed6dc8fc --- /dev/null +++ b/src/metatrain/pet_8/modules/nef.py @@ -0,0 +1,248 @@ +""" +Module with functions to manipulate NEF (Node Edge Feature) arrays. + +The NEF representation is what the internals of PET use. +In the NEF representation, the first dimension is the center node +(i.e. the "i" node in an "i -> j" edge), and the second dimension +is the edges for that node. Not all center nodes have the same number +of edges, so padding is used to ensure that all nodes have the same +number of edges. + +Most of the functions have the purpose of converting between +edge arrays with shape (n_edges, ...) and NEF arrays with shape +(n_nodes, n_edges_per_node, ...). +""" + +from typing import List, Optional, Tuple + +import torch + + +def get_nef_indices( + centers: torch.Tensor, n_nodes: int, n_edges_per_node: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes tensors of indices useful to convert between edge + and NEF layouts; the usage and function of `nef_indices` and + `nef_to_edges_neighbor` is clear in the ``edge_array_to_nef`` + and ``nef_array_to_edges`` functions below. + + :param centers: A 1D tensor of shape (n_edges,) containing the + indices of the center nodes for each edge, with the center nodes + being the "i" node in an "i -> j" edge. + :param n_nodes: The number of nodes in the graph. + :param n_edges_per_node: The maximum number of edges per node. 
+ + :return: A tuple with three tensors (nef_indices, nef_to_edges_neighbor, nef_mask). + In particular: + nef_array = edge_array[nef_indices] + edge_array = nef_array[centers, nef_to_edges_neighbor] + The third output, nef_mask, is a mask that can be used to + filter out the padding values in the NEF array, as different + nodes will have, in general, different number of edges. + """ + + bincount = torch.bincount(centers, minlength=n_nodes) + + arange = torch.arange(n_edges_per_node, device=centers.device) + arange_expanded = arange.view(1, -1).expand(n_nodes, -1) + nef_mask = arange_expanded < bincount.view(-1, 1) + + argsort = torch.argsort(centers, stable=True) + + nef_indices = torch.zeros( + (n_nodes, n_edges_per_node), dtype=torch.long, device=centers.device + ) + nef_indices[nef_mask] = argsort + + nef_to_edges_neighbor = torch.empty_like(centers, dtype=torch.long) + nef_to_edges_neighbor[argsort] = arange_expanded[nef_mask] + + return nef_indices, nef_to_edges_neighbor, nef_mask + + +def get_corresponding_edges(array: torch.Tensor) -> torch.Tensor: + """ + Computes the corresponding edge (i.e., the edge that goes in the + opposite direction) for each edge in the array; this is useful + in the message-passing operation. + + :param array: A 2D tensor of shape (n_edges, 5). For each i -> j + edge, the first column contains the index of the center node i, + the second column contains the index of the neighbor node j, + and the last three columns contain the cell shifts along x, y, and z + directions, respectively. + + :return: A 1D tensor of shape (n_edges,) containing, for each edge, + the index of the corresponding edge (i.e., the edge that goes + in the opposite direction). If the input array is empty, an + empty tensor is returned. + """ + + if array.numel() == 0: + return torch.empty((0,), dtype=array.dtype, device=array.device) + + array = array.to(torch.int64) # avoid overflow + + centers = array[:, 0] + neighbors = array[:, 1] + cell_shifts_x = array[:, 2] + cell_shifts_y = array[:, 3] + cell_shifts_z = array[:, 4] + + # will be useful later + negative_cell_shifts_x = -cell_shifts_x + negative_cell_shifts_y = -cell_shifts_y + negative_cell_shifts_z = -cell_shifts_z + + # create a unique identifier for each edge + # first, we shift the cell_shifts so that the minimum value is 0 + min_cell_shift_x = cell_shifts_x.min() + cell_shifts_x = cell_shifts_x - min_cell_shift_x + negative_cell_shifts_x = negative_cell_shifts_x - min_cell_shift_x + + min_cell_shift_y = cell_shifts_y.min() + cell_shifts_y = cell_shifts_y - min_cell_shift_y + negative_cell_shifts_y = negative_cell_shifts_y - min_cell_shift_y + + min_cell_shift_z = cell_shifts_z.min() + cell_shifts_z = cell_shifts_z - min_cell_shift_z + negative_cell_shifts_z = negative_cell_shifts_z - min_cell_shift_z + + max_centers_neigbors = centers.max() + 1 # same as neighbors.max() + 1 + max_shift_x = cell_shifts_x.max() + 1 + max_shift_y = cell_shifts_y.max() + 1 + max_shift_z = cell_shifts_z.max() + 1 + + size_1 = max_shift_z + size_2 = max_shift_y * size_1 + size_3 = max_shift_x * size_2 + size_4 = max_centers_neigbors * size_3 + + unique_id = ( + centers * size_4 + + neighbors * size_3 + + cell_shifts_x * size_2 + + cell_shifts_y * size_1 + + cell_shifts_z + ) + + # the inverse is the same, but centers and neighbors are swapped + # and we use the negative values of the cell_shifts + unique_id_inverse = ( + neighbors * size_4 + + centers * size_3 + + negative_cell_shifts_x * size_2 + + negative_cell_shifts_y * size_1 + + 
negative_cell_shifts_z + ) + + unique_id_argsort = unique_id.argsort() + unique_id_inverse_argsort = unique_id_inverse.argsort() + + corresponding_edges = torch.empty_like(centers) + corresponding_edges[unique_id_argsort] = unique_id_inverse_argsort + + return corresponding_edges.to(array.dtype) + + +def edge_array_to_nef( + edge_array: torch.Tensor, + nef_indices: torch.Tensor, + mask: Optional[torch.Tensor] = None, + fill_value: float = 0.0, +) -> torch.Tensor: + """Converts an edge array to a NEF array. + + :param edge_array: A tensor where the first dimension is the index of + the edge, i.e. with shape (n_edges, ...). + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param mask: An optional boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. If provided, + the output NEF array will have the values in the positions + where the mask is False set to ``fill_value``. + :param fill_value: The value to use to fill the positions in the + NEF array where the mask is False. Only used if ``mask`` is + provided. + + :return: A tensor with the same information as ``edge_array``, + but in NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + If ``mask`` is provided, the values in the positions where + the mask is False are set to ``fill_value``. + """ + if mask is None: + return edge_array[nef_indices] + else: + return torch.where( + mask.reshape(mask.shape + (1,) * (len(edge_array.shape) - 1)), + edge_array[nef_indices], + fill_value, + ) + + +def nef_array_to_edges( + nef_array: torch.Tensor, centers: torch.Tensor, nef_to_edges_neighbor: torch.Tensor +) -> torch.Tensor: + """Converts a NEF array to an edge array. + + :param nef_array: A tensor where the first two dimensions are the + indices of the NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + :param centers: The indices of the center nodes for each edge. + :param nef_to_edges_neighbor: The indices of the edges for each + neighbor in the NEF layout, as returned by the ``get_nef_indices`` function. + + :return: A tensor with the same information as ``nef_array``, + but in edge layout, i.e. with shape (n_edges, ...). + """ + return nef_array[centers, nef_to_edges_neighbor] + + +def compute_reversed_neighbor_list( + nef_indices: torch.Tensor, + corresponding_edges: torch.Tensor, + nef_mask: torch.Tensor, +) -> torch.Tensor: + """ + Creates a reversed neighborlist, where for each + center atom `i` and its neighbor `j` in the original + neighborlist, the position of atom `i` in the list + of neighbors of atom `j` is returned. + + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param corresponding_edges: The indices of the corresponding edges, + as returned by the ``get_corresponding_edges`` function. + :param nef_mask: A boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. + :return: A tensor of the same shape as ``nef_indices``, + where each entry contains the position of the center + atom in the neighborlist of the corresponding neighbor atom. 
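+    For example, if the edge ``i -> j`` is stored at position ``k`` in the
+    neighbor list of atom ``i``, and atom ``i`` appears at position 2 in the
+    neighbor list of atom ``j``, the returned tensor contains 2 at ``[i, k]``.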
+ """ + num_atoms, max_num_neighbors = nef_indices.shape + + flat_edge_indices = nef_indices.reshape(-1) + flat_positions = torch.arange(max_num_neighbors, device=nef_indices.device).repeat( + num_atoms + ) + flat_mask = nef_mask.reshape(-1) + + if flat_edge_indices.numel() == 0: + max_edge_index = 0 + else: + max_edge_index = int(flat_edge_indices.max().item()) + 1 + size: List[int] = [max_edge_index] + + edge_index_to_position = torch.full( + size, + 0, + dtype=torch.long, + device=nef_indices.device, + ) + edge_index_to_position[flat_edge_indices[flat_mask]] = flat_positions[flat_mask] + + reverse_edge_idx = corresponding_edges[nef_indices] + reversed_neighbor_list = edge_index_to_position[reverse_edge_idx] + reversed_neighbor_list = reversed_neighbor_list.masked_fill(~nef_mask, 0) + + return reversed_neighbor_list diff --git a/src/metatrain/pet_8/modules/structures.py b/src/metatrain/pet_8/modules/structures.py new file mode 100644 index 0000000000..f7e67f9eaf --- /dev/null +++ b/src/metatrain/pet_8/modules/structures.py @@ -0,0 +1,377 @@ +from typing import List, Optional, Tuple + +import torch +from metatensor.torch import Labels +from metatomic.torch import NeighborListOptions, System + +from .adaptive_cutoff import get_adaptive_cutoffs +from .nef import ( + compute_reversed_neighbor_list, + edge_array_to_nef, + get_corresponding_edges, + get_nef_indices, +) +from .utilities import cutoff_func_bump, cutoff_func_cosine + + +def concatenate_structures( + systems: List[System], + neighbor_list_options: NeighborListOptions, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, +]: + """ + Concatenate a list of systems into a single batch. + + :param systems: List of systems to concatenate. + :param neighbor_list_options: Options for the neighbor list. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the concatenated positions, centers, neighbors, + species, cells, cell shifts, system indices, and sample labels. 
+ """ + + positions: List[torch.Tensor] = [] + centers: List[torch.Tensor] = [] + neighbors: List[torch.Tensor] = [] + species: List[torch.Tensor] = [] + cell_shifts: List[torch.Tensor] = [] + cells: List[torch.Tensor] = [] + system_indices: List[torch.Tensor] = [] + atom_indices: List[torch.Tensor] = [] + node_counter = 0 + + for i, system in enumerate(systems): + assert len(system.known_neighbor_lists()) >= 1, "no neighbor list found" + neighbor_list = system.get_neighbor_list(neighbor_list_options) + nl_values = neighbor_list.samples.values + + centers_values = nl_values[:, 0] + neighbors_values = nl_values[:, 1] + cell_shifts_values = nl_values[:, 2:] + + if selected_atoms is not None: + system_selected_atoms = selected_atoms.values[:, 1][ + selected_atoms.values[:, 0] == i + ] + unique_centers = torch.unique(centers_values) + system_selected_atoms = torch.unique( + torch.cat([system_selected_atoms, unique_centers]) + ) + # calculate the mapping from the ghost atoms to the real atoms + if torch.numel(unique_centers) == 0: + max_center_index = -1 + else: + max_center_index = int(unique_centers.max()) + ghost_to_real_index = torch.full( + [ + max_center_index + 1, + ], + -1, + device=centers_values.device, + dtype=centers_values.dtype, + ) + for j, unique_center_index in enumerate(unique_centers): + ghost_to_real_index[unique_center_index] = j + + centers_values = ghost_to_real_index[centers_values] + neighbors_values = ghost_to_real_index[neighbors_values] + else: + system_selected_atoms = torch.arange( + len(system), device=system.positions.device + ) + + positions.append(system.positions[system_selected_atoms]) + species.append(system.types[system_selected_atoms]) + + centers.append(centers_values + node_counter) + neighbors.append(neighbors_values + node_counter) + cell_shifts.append(cell_shifts_values) + + cells.append(system.cell) + + node_counter += len(system_selected_atoms) + system_indices.append( + torch.full((len(system_selected_atoms),), i, device=system.positions.device) + ) + atom_indices.append( + torch.arange(len(system_selected_atoms), device=system.positions.device) + ) + + positions = torch.cat(positions) + centers = torch.cat(centers) + neighbors = torch.cat(neighbors) + species = torch.cat(species) + cells = torch.stack(cells) + cell_shifts = torch.cat(cell_shifts) + system_indices = torch.cat(system_indices) + atom_indices = torch.cat(atom_indices) + + sample_values = torch.stack( + [system_indices, atom_indices], + dim=1, + ) + sample_labels = Labels( + names=["system", "atom"], + values=sample_values, + assume_unique=True, + ) + + return ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) + + +def systems_to_batch( + systems: List[System], + options: NeighborListOptions, + all_species_list: List[int], + species_to_species_index: torch.Tensor, + cutoff_function: str, + cutoff_width: float, + num_neighbors_adaptive: Optional[float] = None, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, + torch.Tensor, + torch.Tensor, +]: + """ + Converts a list of systems to a batch required for the PET model. + + :param systems: List of systems to convert to a batch. + :param options: Options for the neighbor list. + :param all_species_list: List of all atomic species in the dataset. + :param species_to_species_index: Mapping from atomic species to species indices. 
+ :param cutoff_function: Type of the smoothing function at the cutoff. + :param cutoff_width: Width of the cutoff function for a cutoff mask. + :param num_neighbors_adaptive: Optional maximum number of neighbors per atom. + If provided, the adaptive cutoff scheme will be used for each atom to + approximately select this number of neighbors. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the batch tensors. + The batch consists of the following tensors: + - `element_indices_nodes`: The atomic species of the central atoms + - `element_indices_neighbors`: The atomic species of the neighboring atoms + - `edge_vectors`: The cartesian edge vectors between the central atoms and their + neighbors + - `edge_distances`: The distances between the central atoms and their neighbors + - `padding_mask`: A padding mask indicating which neighbors are real, and which + are padded + - `reverse_neighbor_index`: The reversed neighbor list for each central atom + - `cutoff_factors`: The cutoff function values for each edge + - `system_indices`: The system index for each atom in the batch + - `sample_labels`: Labels indicating the system and atom indices for each atom + + """ + ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) = concatenate_structures(systems, options, selected_atoms) + + # somehow the backward of this operation is very slow at evaluation, + # where there is only one cell, therefore we simplify the calculation + # for that case + if len(cells) == 1: + cell_contributions = cell_shifts.to(cells.dtype) @ cells[0] + else: + cell_contributions = torch.einsum( + "ab, abc -> ac", + cell_shifts.to(cells.dtype), + cells[system_indices[centers]], + ) + edge_vectors = positions[neighbors] - positions[centers] + cell_contributions + edge_distances = torch.norm(edge_vectors, dim=-1) + 1e-15 + + if selected_atoms is not None: + if torch.numel(centers) == 0: + num_nodes = 0 + else: + num_nodes = int(centers.max()) + 1 + else: + num_nodes = len(positions) + + if num_neighbors_adaptive is not None: + with torch.profiler.record_function("PET::get_adaptive_cutoffs"): + # Adaptive cutoff scheme to approximately select `num_neighbors_adaptive` + # neighbors for each atom + atomic_cutoffs = get_adaptive_cutoffs( + centers, + edge_distances, + num_neighbors_adaptive, + num_nodes, + options.cutoff, + cutoff_width=cutoff_width, + ) + # Symmetrize the cutoffs between pairs of atoms (PET needs this symmetry + # due to its corresponding edge indexing ij -> ji) + pair_cutoffs = (atomic_cutoffs[centers] + atomic_cutoffs[neighbors]) / 2.0 + with torch.profiler.record_function("PET::adaptive_cutoff_masking"): + # Apply cutoff mask + cutoff_mask = edge_distances <= pair_cutoffs + + pair_cutoffs = pair_cutoffs[cutoff_mask] + centers = centers[cutoff_mask] + neighbors = neighbors[cutoff_mask] + edge_vectors = edge_vectors[cutoff_mask] + cell_shifts = cell_shifts[cutoff_mask] + edge_distances = edge_distances[cutoff_mask] + else: + pair_cutoffs = options.cutoff * torch.ones( + len(centers), device=positions.device, dtype=positions.dtype + ) + + num_neighbors = torch.bincount(centers, minlength=num_nodes) + max_edges_per_node = int(torch.max(num_neighbors)) + + # uncomment these to print out stats on the adaptive cutoff behavior + # print("adaptive_cutoffs", *pair_cutoffs.tolist()) + # print("num_neighbors", *num_neighbors.tolist()) + + if cutoff_function.lower() == "bump": + # use bump switching 
function for adaptive cutoff + cutoff_factors = cutoff_func_bump(edge_distances, pair_cutoffs, cutoff_width) + elif cutoff_function.lower() == "cosine": + # backward-compatible cosine swithcing for fixed cutoff + cutoff_factors = cutoff_func_cosine(edge_distances, pair_cutoffs, cutoff_width) + else: + raise ValueError( + f"Unknown cutoff function type: {cutoff_function}. " + f"Supported types are 'Cosine' and 'Bump'." + ) + + # Convert to NEF (Node-Edge-Feature) format: + nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices( + centers, num_nodes, max_edges_per_node + ) + + # Element indices + element_indices_nodes = species_to_species_index[species] + element_indices_neighbors = element_indices_nodes[neighbors] + + # Send everything to NEF: + edge_vectors = edge_array_to_nef(edge_vectors, nef_indices) + edge_distances = torch.sqrt(torch.sum(edge_vectors**2, dim=2) + 1e-15) + element_indices_neighbors = edge_array_to_nef( + element_indices_neighbors, nef_indices + ) + cutoff_factors = edge_array_to_nef(cutoff_factors, nef_indices, nef_mask, 0.0) + + corresponding_edges = get_corresponding_edges( + torch.concatenate( + [centers.unsqueeze(-1), neighbors.unsqueeze(-1), cell_shifts], + dim=-1, + ) + ) + + # These are the two arrays we need for message passing with edge reversals, + # if indexing happens in a two-dimensional way: + # edges_ji = edges_ij[reversed_neighbor_list, neighbors_index] + reversed_neighbor_list = compute_reversed_neighbor_list( + nef_indices, corresponding_edges, nef_mask + ) + neighbors_index = edge_array_to_nef(neighbors, nef_indices).to(torch.int64) + + # Here, we compute the array that allows indexing into a flattened + # version of the edge array (where the first two dimensions are merged): + reverse_neighbor_index = ( + neighbors_index * neighbors_index.shape[1] + reversed_neighbor_list + ) + # At this point, we have `reverse_neighbor_index[~nef_mask] = 0`, which however + # creates too many of the same index which slows down backward enormously. + # (See see https://github.com/pytorch/pytorch/issues/41162) + # We therefore replace the padded indices with a sequence of unique indices. + reverse_neighbor_index[~nef_mask] = torch.arange( + int(torch.sum(~nef_mask)), device=reverse_neighbor_index.device + ) + + return ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + nef_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + sample_labels, + centers, + nef_to_edges_neighbor, + ) + + +def get_pair_sample_labels( + systems: List[System], + sample_labels: Labels, + nl_options: NeighborListOptions, + device: torch.device, +) -> Labels: + """ + Builds the pair samples labels for the input ``systems``, based on the pre-computed + neighbor list. These are 'off-site', i.e. not including self-interactions. + + :param systems: List of systems to build the pair sample labels for. + :param sample_labels: The sample labels for per-atom quantities. + :param nl_options: The neighbor list options to use for building the offsite labels. + :param device: The device to put the labels on. + :return: A dictionary with the pair sample labels for the onsite and offsite blocks. 
+ """ + sample_names = [ + "system", + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ] + + pair_sample_values = [] + for system_idx, system in enumerate(systems): + neighbor_list = system.get_neighbor_list(nl_options) + nl_values = neighbor_list.samples.values + + pair_sample_values.append( + torch.hstack( + [ + torch.full( + (nl_values.shape[0], 1), + system_idx, + dtype=torch.int32, + device=device, + ), + nl_values, + ], + ) + ) + pair_sample_values = torch.vstack(pair_sample_values) + pair_sample_labels = Labels(sample_names, pair_sample_values).to(device=device) + + return pair_sample_labels diff --git a/src/metatrain/pet_8/modules/transformer.py b/src/metatrain/pet_8/modules/transformer.py new file mode 100644 index 0000000000..8682d5d584 --- /dev/null +++ b/src/metatrain/pet_8/modules/transformer.py @@ -0,0 +1,571 @@ +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from .utilities import DummyModule +from sphericart.torch import SolidHarmonics + + +AVAILABLE_NORMALIZATIONS = ["LayerNorm", "RMSNorm"] +AVAILABLE_TRANSFORMER_TYPES = ["PostLN", "PreLN"] +AVAILABLE_ACTIVATIONS = ["SiLU", "SwiGLU"] + + +class FeedForward(nn.Module): + def __init__(self, d_model: int, dim_feedforward: int, activation: str) -> None: + super().__init__() + + # Check if activation is "swiglu" string + if activation.lower() == "swiglu": + # SwiGLU mode: single projection produces both "value" and "gate" + self.w_in = nn.Linear(d_model, 2 * dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = torch.nn.Identity() + self.is_swiglu = True + else: + # Standard mode: regular activation function + self.w_in = nn.Linear(d_model, dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = getattr(F, activation.lower()) + self.is_swiglu = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_swiglu: + # SwiGLU activation: split into value and gate + v, g = self.w_in(x).chunk(2, dim=-1) + x = v * torch.sigmoid(g) + x = self.w_out(x) + else: + # Standard activation + x = self.w_in(x) + x = self.activation(x) + x = self.w_out(x) + return x + + +class AttentionBlock(nn.Module): + """ + Multi-head attention block. + + :param total_dim: The total dimension of the input and output tensors. + :param num_heads: The number of attention heads. + :param temperature: An additional scaling factor for attention scores. + This is combined with the standard scaling by the square root of + the head dimension. + :param epsilon: A small value to avoid division by zero. + """ + + def __init__( + self, + total_dim: int, + num_heads: int, + temperature: float, + epsilon: float = 1e-15, + ) -> None: + super(AttentionBlock, self).__init__() + + self.input_linear = nn.Linear(total_dim, 3 * total_dim) + self.output_linear = nn.Linear(total_dim, total_dim) + + self.num_heads = num_heads + self.epsilon = epsilon + self.temperature = temperature + if total_dim % num_heads != 0: + raise ValueError("total dimension is not divisible by the number of heads") + self.head_dim = total_dim // num_heads + + def forward( + self, x: torch.Tensor, cutoff_factors: torch.Tensor, use_manual_attention: bool + ) -> torch.Tensor: + """ + Forward pass for the attention block. 
+ + :param x: The input tensor, of shape (batch_size, seq_length, total_dim) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: The output tensor, of shape (batch_size, seq_length, total_dim) + """ + initial_shape = x.shape + x = self.input_linear(x) + x = x.reshape( + initial_shape[0], initial_shape[1], 3, self.num_heads, self.head_dim + ) + x = x.permute(2, 0, 3, 1, 4) + + queries, keys, values = x[0], x[1], x[2] + attn_weights = torch.clamp(cutoff_factors[:, None, :, :], self.epsilon) + attn_weights = torch.log(attn_weights) + if use_manual_attention: + x = manual_attention(queries, keys, values, attn_weights, self.temperature) + else: + x = torch.nn.functional.scaled_dot_product_attention( + queries, + keys, + values, + attn_mask=attn_weights, + scale=1.0 / (self.head_dim**0.5 * self.temperature), + ) + x = x.transpose(1, 2).reshape(initial_shape) + x = self.output_linear(x) + return x + + +class TransformerLayer(torch.nn.Module): + """ + Single layer of a Transformer. + + :param d_model: The dimension of the model. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param temperature: An additional scaling factor for attention scores. 
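+    With "PostLN", layer normalization is applied after the attention and
+    feed-forward residual additions; with "PreLN", the inputs to the attention
+    and feed-forward blocks are normalized instead.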
+ """ + + def __init__( + self, + d_model: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + temperature: float = 1.0, + ) -> None: + super(TransformerLayer, self).__init__() + self.attention = AttentionBlock(d_model, n_heads, temperature) + self.transformer_type = transformer_type + self.d_model = d_model + norm_class = getattr(nn, norm) + self.norm_attention = norm_class(d_model) + self.norm_mlp = norm_class(d_model) + self.mlp = FeedForward(d_model, dim_feedforward, activation) + self.expanded_node_features = False + if dim_node_features != d_model: + self.expanded_node_features = True + self.center_contraction = nn.Linear(dim_node_features, d_model) + self.center_expansion = nn.Linear(d_model, dim_node_features) + self.norm_center_features = norm_class(dim_node_features) + self.center_mlp = FeedForward( + dim_node_features, 2 * dim_node_features, activation + ) + else: + self.center_contraction = torch.nn.Identity() + self.center_expansion = torch.nn.Identity() + self.norm_center_features = torch.nn.Identity() + self.center_mlp = torch.nn.Identity() + + def _forward_pre_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + new_tokens = self.attention( + self.norm_attention(tokens), cutoff_factors, use_manual_attention + ) + output_node_embeddings, output_edge_embeddings = torch.split( + new_tokens, [1, new_tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + + output_edge_embeddings = edge_embeddings + output_edge_embeddings + output_edge_embeddings = output_edge_embeddings + self.mlp( + self.norm_mlp(output_edge_embeddings) + ) + + return output_node_embeddings, output_edge_embeddings + + def _forward_post_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + tokens = self.norm_attention( + tokens + self.attention(tokens, cutoff_factors, use_manual_attention) + ) + tokens = self.norm_mlp(tokens + self.mlp(tokens)) + output_node_embeddings, output_edge_embeddings = torch.split( + tokens, [1, tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + return output_node_embeddings, output_edge_embeddings + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward 
pass for a single Transformer layer. + + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + if self.transformer_type == "PostLN": + node_embeddings, edge_embeddings = self._forward_post_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + if self.transformer_type == "PreLN": + node_embeddings, edge_embeddings = self._forward_pre_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class Transformer(torch.nn.Module): + """ + Transformer implementation. + + :param d_model: The dimension of the model. + :param num_layers: The number of transformer layers. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param attention_temperature: The temperature scaling factor for attention + scores. This is combined with the standard scaling by the square root of + the head dimension. + """ + + def __init__( + self, + d_model: int, + num_layers: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + attention_temperature: float = 1.0, + ) -> None: + super(Transformer, self).__init__() + if norm not in AVAILABLE_NORMALIZATIONS: + raise ValueError( + f"Unknown normalization flag: {norm}. " + f"Please choose from: {AVAILABLE_NORMALIZATIONS}" + ) + + if transformer_type not in AVAILABLE_TRANSFORMER_TYPES: + raise ValueError( + f"Unknown transformer flag: {transformer_type}. " + f"Please choose from: {AVAILABLE_TRANSFORMER_TYPES}" + ) + self.transformer_type = transformer_type + + if activation not in AVAILABLE_ACTIVATIONS: + raise ValueError( + f"Unknown activation flag: {activation}. " + f"Please choose from: {AVAILABLE_ACTIVATIONS}" + ) + + self.layers = nn.ModuleList( + [ + TransformerLayer( + d_model=d_model, + n_heads=n_heads, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + temperature=attention_temperature, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the Transformer. 
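A small usage sketch for the ``Transformer`` stack defined above, with arbitrary sizes and the token layout used by ``CartesianTransformer`` further below (one central-atom token followed by one token per neighbor, and cutoff factors that include the central token). It assumes the classes in this module and is only meant to show the expected shapes:

import torch

# tiny 2-layer stack acting on 4 atoms, each with 6 neighbor tokens of width 32
model = Transformer(d_model=32, num_layers=2, n_heads=4, dim_node_features=32)

node_tokens = torch.randn(4, 1, 32)   # one central-atom token per atom
edge_tokens = torch.randn(4, 6, 32)   # one token per neighbor
cutoffs = torch.rand(4, 7, 7)         # cutoff factors, including the central token

node_out, edge_out = model(node_tokens, edge_tokens, cutoffs, use_manual_attention=True)
print(node_out.shape, edge_out.shape)  # torch.Size([4, 1, 32]) torch.Size([4, 6, 32])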
+ + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + for layer in self.layers: + node_embeddings, edge_embeddings = layer( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class CartesianTransformer(torch.nn.Module): + """ + Cartesian Transformer implementation for handling 3D coordinates. + + :param cutoff: The cutoff distance for neighbor interactions. + :param cutoff_width: The width of the cutoff function. + :param d_model: The dimension of the model. + :param n_head: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param n_layers: The number of transformer layers. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param attention_temperature: The temperature scaling factor for attention scores. + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param n_atomic_species: The number of atomic species. + :param is_first: Whether this is the first transformer in the model. + """ + + def __init__( + self, + cutoff: float, + cutoff_width: float, + d_model: int, + n_head: int, + dim_node_features: int, + dim_feedforward: int, + n_layers: int, + norm: str, + activation: str, + attention_temperature: float, + transformer_type: str, + n_atomic_species: int, + is_first: bool, + ) -> None: + super(CartesianTransformer, self).__init__() + self.is_first = is_first + self.cutoff = cutoff + self.cutoff_width = cutoff_width + self.trans = Transformer( + d_model=d_model, + num_layers=n_layers, + n_heads=n_head, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + attention_temperature=attention_temperature, + ) + + self.l_max = 10 + self.n_rad = 32 + self.spherical_harmonics = SolidHarmonics(l_max=self.l_max) + + self.edge_embedder = nn.Linear((self.l_max + 1) ** 2 + self.n_rad, d_model) + self.rmsnorm = nn.LayerNorm(d_model) + + if not is_first: + n_merge = 3 + else: + n_merge = 2 + + self.compress = nn.Sequential( + nn.Linear(n_merge * d_model, d_model), + torch.nn.SiLU(), + nn.Linear(d_model, d_model), + ) + + self.neighbor_embedder = DummyModule() # for torchscript + if not is_first: + self.neighbor_embedder = nn.Embedding(n_atomic_species, d_model) + + def forward( + self, + input_node_embeddings: torch.Tensor, + input_messages: torch.Tensor, + element_indices_neighbors: torch.Tensor, + edge_vectors: torch.Tensor, + padding_mask: torch.Tensor, + edge_distances: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the CartesianTransformer. 
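A quick sanity check on the sizes hard-coded in the constructor above: with ``l_max = 10`` and ``n_rad = 32`` each edge enters ``edge_embedder`` with 153 geometric features, and the ``compress`` MLP sees two ``d_model`` blocks in the first GNN layer (edge features and incoming messages) or three in later layers (plus the neighbor-species embedding). A short sketch of the bookkeeping, using the default ``d_pet`` of 128 for ``d_model``:

l_max, n_rad = 10, 32
n_solid_harmonics = (l_max + 1) ** 2            # 121 components up to l = 10
edge_embedder_in = n_solid_harmonics + n_rad    # 153 geometric features per edge

d_model = 128                                   # e.g. the default d_pet
n_merge_first, n_merge_later = 2, 3             # later layers also get neighbor-species embeddings
print(edge_embedder_in, n_merge_first * d_model, n_merge_later * d_model)  # 153 256 384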
+ + :param input_node_embeddings: The input node embeddings, of shape + (n_nodes, d_model) + :param input_messages: The input messages to the transformer, of shape + (n_nodes, max_num_neighbors, d_model) + :param element_indices_neighbors: The atomic species of the neighboring atoms, + of shape (n_nodes, max_num_neighbors) + :param edge_vectors: The cartesian edge vectors between the central atoms and + their neighbors, of shape (n_nodes, max_num_neighbors, 3) + :param padding_mask: A padding mask indicating which neighbors are real, and + which are padded, of shape (n_nodes, max_num_neighbors) + :param edge_distances: The distances between the central atoms and their + neighbors, of shape (n_nodes, max_num_neighbors) + :param cutoff_factors: The cutoff factors for the edges, of shape + (n_nodes, max_num_neighbors) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (n_nodes, d_model) + - The output edge embeddings, of shape (n_nodes, max_num_neighbors, d_model) + """ + node_embeddings = input_node_embeddings + + # Compute spherical harmonics + sph = self.spherical_harmonics(edge_vectors.reshape(-1, 3)).reshape( + edge_vectors.shape[0], edge_vectors.shape[1], -1 + ) + + # Compute Bessel RBFs + eps = 1e-8 + r = (edge_distances / self.cutoff).clamp(0.0, 1.0) + x = ( + torch.pi + * r[:, :, None] + * torch.arange(1, self.n_rad + 1, device=edge_vectors.device)[None, None, :] + ) + rbf = torch.sin(x) / (r[:, :, None] + eps) + rbf = rbf * cutoff_factors[:, :, None] + + edge_embeddings = torch.concat([sph, rbf], dim=-1) + edge_embeddings = self.edge_embedder(edge_embeddings) + edge_embeddings = self.rmsnorm(edge_embeddings) + + if not self.is_first: + neighbor_elements_embeddings = self.neighbor_embedder( + element_indices_neighbors + ) + edge_tokens = torch.cat( + [edge_embeddings, neighbor_elements_embeddings, input_messages], dim=2 + ) + else: + neighbor_elements_embeddings = torch.empty( + 0, device=edge_vectors.device, dtype=edge_vectors.dtype + ) # for torch script + edge_tokens = torch.cat([edge_embeddings, input_messages], dim=2) + + edge_tokens = self.compress(edge_tokens) + # tokens = torch.cat([node_elements_embedding[:, None, :], tokens], dim=1) + + padding_mask_with_central_token = torch.ones( + padding_mask.shape[0], dtype=torch.bool, device=padding_mask.device + ) + total_padding_mask = torch.cat( + [padding_mask_with_central_token[:, None], padding_mask], dim=1 + ) + + cutoff_subfactors = torch.ones( + padding_mask.shape[0], + dtype=cutoff_factors.dtype, + device=padding_mask.device, + ) + cutoff_factors = torch.cat([cutoff_subfactors[:, None], cutoff_factors], dim=1) + cutoff_factors[~total_padding_mask] = 0.0 + + cutoff_factors = cutoff_factors[:, None, :] + cutoff_factors = cutoff_factors.repeat(1, cutoff_factors.shape[2], 1) + + initial_num_tokens = edge_vectors.shape[1] + max_num_tokens = input_messages.shape[1] + + output_node_embeddings, output_edge_embeddings = self.trans( + node_embeddings[:, None, :], + edge_tokens[:, :max_num_tokens, :], + cutoff_factors=cutoff_factors[ + :, : (max_num_tokens + 1), : (max_num_tokens + 1) + ], + use_manual_attention=use_manual_attention, + ) + if max_num_tokens < initial_num_tokens: + padding = torch.zeros( + output_edge_embeddings.shape[0], + initial_num_tokens - max_num_tokens, 
+ output_edge_embeddings.shape[2], + device=output_edge_embeddings.device, + ) + output_edge_embeddings = torch.cat([output_edge_embeddings, padding], dim=1) + output_node_embeddings = output_node_embeddings.squeeze(1) + return output_node_embeddings, output_edge_embeddings + + +def manual_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: torch.Tensor, + temperature: float, +) -> torch.Tensor: + """ + Implements the attention operation manually, using basic PyTorch operations. + We need it because the built-in PyTorch attention does not support double backward, + which is needed when training with conservative forces. + + :param q: The queries + :param k: The keys + :param v: The values + :param attn_mask: The attention mask + :param temperature: An additional scaling factor for attention scores. + :return: The result of the attention operation + """ + attention_weights = ( + torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5 * temperature) + ) + attn_mask + attention_weights = attention_weights.softmax(dim=-1) + attention_output = torch.matmul(attention_weights, v) + return attention_output diff --git a/src/metatrain/pet_8/modules/utilities.py b/src/metatrain/pet_8/modules/utilities.py new file mode 100644 index 0000000000..5795c5aed8 --- /dev/null +++ b/src/metatrain/pet_8/modules/utilities.py @@ -0,0 +1,63 @@ +import torch + + +def cutoff_func_bump( + values: torch.Tensor, cutoff: torch.Tensor, width: float, eps: float = 1e-6 +) -> torch.Tensor: + """ + Bump cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :param eps: Avoid computing at values too close to the edges. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + f[mask_active] = 0.5 * ( + 1 + torch.tanh(1 / torch.tan(torch.pi * scaled_values[mask_active])) + ) + f[mask_smaller] = 1.0 + + return f + + +def cutoff_func_cosine( + values: torch.Tensor, cutoff: torch.Tensor, width: float +) -> torch.Tensor: + """ + Cosine cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + + f[mask_active] = 0.5 + 0.5 * torch.cos(torch.pi * scaled_values[mask_active]) + f[mask_smaller] = 1.0 + return f + + +class DummyModule(torch.nn.Module): + """Dummy torch module to make torchscript happy. 
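Both cutoff functions above return factors that are 1 below ``cutoff - width``, 0 at and beyond ``cutoff``, and interpolate smoothly in between; the bump variant is flatter at both ends of the transition. A standalone sketch with arbitrary distances, assuming the two functions defined above are in scope:

import torch

values = torch.tensor([3.0, 4.1, 4.25, 4.4, 5.0])   # example distances
cutoff = torch.full_like(values, 4.5)                # per-node cutoff radius
width = 0.5

print(cutoff_func_cosine(values, cutoff, width))  # ~[1.00, 0.90, 0.50, 0.10, 0.00]
print(cutoff_func_bump(values, cutoff, width))    # also 1 below 4.0 and 0 from 4.5 on, flatter at the edges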
+ This model should never be run""" + + def __init__(self) -> None: + super(DummyModule, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise RuntimeError("This model should never be run") diff --git a/src/metatrain/pet_8/trainer.py b/src/metatrain/pet_8/trainer.py new file mode 100644 index 0000000000..49e82212bc --- /dev/null +++ b/src/metatrain/pet_8/trainer.py @@ -0,0 +1,639 @@ +import copy +import logging +import math +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Union, cast + +import torch +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader, DistributedSampler + +from metatrain.utils.abc import ModelInterface, TrainerInterface +from metatrain.utils.additive import get_remove_additive_transform +from metatrain.utils.augmentation import RotationalAugmenter +from metatrain.utils.data import ( + CollateFn, + CombinedDataLoader, + Dataset, + get_num_workers, + unpack_batch, + validate_num_workers, +) +from metatrain.utils.distributed.batch_utils import should_skip_batch +from metatrain.utils.distributed.distributed_data_parallel import ( + DistributedDataParallel, +) +from metatrain.utils.distributed.slurm import DistributedEnvironment +from metatrain.utils.evaluate_model import evaluate_model +from metatrain.utils.io import check_file_extension +from metatrain.utils.logging import ROOT_LOGGER, MetricLogger +from metatrain.utils.loss import LossAggregator, LossSpecification +from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists_transform, +) +from metatrain.utils.per_atom import average_by_num_atoms +from metatrain.utils.scaler import get_remove_scale_transform +from metatrain.utils.transfer import batch_to + +from . import checkpoints +from .documentation import TrainerHypers +from .model import PET +from .modules.finetuning import apply_finetuning_strategy + + +def get_scheduler( + optimizer: torch.optim.Optimizer, + train_hypers: TrainerHypers, + steps_per_epoch: int, +) -> LambdaLR: + """ + Get a CosineAnnealing learning-rate scheduler with warmup + + :param optimizer: The optimizer for which to create the scheduler. + :param train_hypers: The training hyperparameters. + :param steps_per_epoch: The number of steps per epoch. + :return: The learning rate scheduler. 
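The multiplier implemented just below ramps linearly from 0 to 1 over the warmup steps and then follows a cosine decay back to 0. A standalone sketch of the same formula, with made-up step counts, to show the shape of the schedule:

import math

def lr_multiplier(step: int, total_steps: int, warmup_fraction: float) -> float:
    warmup_steps = int(warmup_fraction * total_steps)
    if step < warmup_steps:
        return step / max(1, warmup_steps)               # linear warmup
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))    # cosine decay to 0

total = 1000 * 100  # e.g. num_epochs * steps_per_epoch
for step in (0, 500, 1000, 50_000, 100_000):
    print(step, round(lr_multiplier(step, total, warmup_fraction=0.01), 3))
# 0 0.0 | 500 0.5 | 1000 1.0 | 50000 0.508 | 100000 0.0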
+ """ + total_steps = train_hypers["num_epochs"] * steps_per_epoch + warmup_steps = int(train_hypers["warmup_fraction"] * total_steps) + min_lr_ratio = 0.0 # hardcoded for now, could be made configurable in the future + + def lr_lambda(current_step: int) -> float: + if current_step < warmup_steps: + # Linear warmup + return float(current_step) / float(max(1, warmup_steps)) + else: + # Cosine decay + progress = (current_step - warmup_steps) / float( + max(1, total_steps - warmup_steps) + ) + cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) + return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay + + scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) + return scheduler + + +class Trainer(TrainerInterface[TrainerHypers]): + __checkpoint_version__ = 12 + + def __init__(self, hypers: TrainerHypers) -> None: + super().__init__(hypers) + + self.optimizer_state_dict: Optional[Dict[str, Any]] = None + self.scheduler_state_dict: Optional[Dict[str, Any]] = None + self.epoch: Optional[int] = None + self.best_epoch: Optional[int] = None + self.best_metric: Optional[float] = None + self.best_model_state_dict: Optional[Dict[str, Any]] = None + self.best_optimizer_state_dict: Optional[Dict[str, Any]] = None + + def train( + self, + model: PET, + dtype: torch.dtype, + devices: List[torch.device], + train_datasets: List[Union[Dataset, torch.utils.data.Subset]], + val_datasets: List[Union[Dataset, torch.utils.data.Subset]], + checkpoint_dir: str, + ) -> None: + assert dtype in PET.__supported_dtypes__ + + is_distributed = self.hypers["distributed"] + is_finetune = self.hypers["finetune"]["read_from"] is not None + + if is_distributed: + if len(devices) > 1: + raise ValueError( + "Requested distributed training with the `multi-gpu` device. " + " If you want to run distributed training with PET, please " + "set `device` to cuda." 
+                )
+            # the calculation of the device number works both when GPUs on different
+            # processes are not visible to each other and when they are
+            distr_env = DistributedEnvironment(self.hypers["distributed_port"])
+            device_number = distr_env.local_rank % torch.cuda.device_count()
+            device = torch.device("cuda", device_number)
+            torch.distributed.init_process_group(backend="nccl", device_id=device)
+            world_size = torch.distributed.get_world_size()
+            rank = torch.distributed.get_rank()
+        else:
+            rank = 0
+            device = devices[0]
+            # only one device, as we don't support non-distributed multi-gpu for now
+
+        if is_distributed:
+            logging.info(f"Training on {world_size} devices with dtype {dtype}")
+        else:
+            logging.info(f"Training on device {device} with dtype {dtype}")
+
+        # Apply fine-tuning strategy if provided
+        if is_finetune:
+            assert self.hypers["finetune"]["read_from"] is not None  # for mypy
+            model = apply_finetuning_strategy(model, self.hypers["finetune"])
+            method = self.hypers["finetune"]["method"]
+            num_params = sum(p.numel() for p in model.parameters())
+            num_trainable_params = sum(
+                p.numel() for p in model.parameters() if p.requires_grad
+            )
+
+            logging.info(f"Applied finetuning strategy: {method}")
+            logging.info(
+                f"Number of trainable parameters: {num_trainable_params} "
+                f"[{num_trainable_params / num_params:.2%}]"
+            )
+            inherit_heads = self.hypers["finetune"]["inherit_heads"]
+            if inherit_heads:
+                logging.info(
+                    "Inheriting initial weights for heads and last layers for targets: "
+                    f"from {list(inherit_heads.values())} to "
+                    f"{list(inherit_heads.keys())}"
+                )
+
+        # Move the model to the device and dtype:
+        model.to(device=device, dtype=dtype)
+        # The additive models of PET are always in float64 (to avoid numerical errors in
+        # the composition weights, which can be very large).
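A tiny standalone illustration of the float64 comment above: near a large baseline value, the spacing between representable float32 numbers is already of the order of 0.01, so small per-structure contributions can be rounded away (the numbers below are arbitrary):

import torch

baseline = 123456.789   # an arbitrary large composition-like value
delta = 0.012           # a small contribution we care about

f32 = torch.tensor(baseline, dtype=torch.float32) + delta
f64 = torch.tensor(baseline, dtype=torch.float64) + delta

print(f32.item() - baseline)  # not 0.012: float32 spacing near 1.2e5 is ~0.0078
print(f64.item() - baseline)  # ~0.012, preserved in float64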
+ for additive_model in model.additive_models: + additive_model.to(dtype=torch.float64) + model.scaler.to(dtype=torch.float64) + + logging.info("Calculating composition weights") + model.additive_models[0].train_model( # this is the composition model + train_datasets, + model.additive_models[1:], + self.hypers["batch_size"], + is_distributed, + self.hypers["atomic_baseline"], + ) + + if self.hypers["scale_targets"]: + logging.info("Calculating scaling weights") + model.scaler.train_model( + train_datasets, + model.additive_models, + self.hypers["batch_size"], + is_distributed, + self.hypers["fixed_scaling_weights"], + ) + + logging.info("Setting up data loaders") + + if is_distributed: + train_samplers = [ + DistributedSampler( + train_dataset, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + ) + for train_dataset in train_datasets + ] + val_samplers = [ + DistributedSampler( + val_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + for val_dataset in val_datasets + ] + else: + train_samplers = [None] * len(train_datasets) + val_samplers = [None] * len(val_datasets) + + # Extract additive models and scaler and move them to CPU/float64 so they + # can be used in the collate function + model.additive_models[0].weights_to(device="cpu", dtype=torch.float64) + additive_models = copy.deepcopy( + model.additive_models.to(dtype=torch.float64, device="cpu") + ) + model.additive_models.to(device) + model.additive_models[0].weights_to(device=device, dtype=torch.float64) + model.scaler.scales_to(device="cpu", dtype=torch.float64) + scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu")) + model.scaler.to(device) + model.scaler.scales_to(device=device, dtype=torch.float64) + + # Create collate functions: + dataset_info = model.dataset_info + train_targets = dataset_info.targets + extra_data_info = dataset_info.extra_data + rotational_augmenter = RotationalAugmenter( + target_info_dict=train_targets, extra_data_info_dict=extra_data_info + ) + requested_neighbor_lists = get_requested_neighbor_lists(model) + collate_fn_train = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ + rotational_augmenter.apply_random_augmentations, + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + collate_fn_val = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ # no augmentation for validation + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + + # Create dataloader for the training datasets: + if self.hypers["num_workers"] is None: + num_workers = get_num_workers() + logging.info( + "Number of workers for data-loading not provided and chosen " + f"automatically. Using {num_workers} workers." + ) + else: + num_workers = self.hypers["num_workers"] + validate_num_workers(num_workers) + + train_dataloaders = [] + for train_dataset, train_sampler in zip( + train_datasets, train_samplers, strict=True + ): + if len(train_dataset) < self.hypers["batch_size"]: + raise ValueError( + f"A training dataset has fewer samples " + f"({len(train_dataset)}) than the batch size " + f"({self.hypers['batch_size']}). " + "Please reduce the batch size." 
+ ) + train_dataloaders.append( + DataLoader( + dataset=train_dataset, + batch_size=self.hypers["batch_size"], + sampler=train_sampler, + shuffle=( + # the sampler takes care of this (if present) + train_sampler is None + ), + drop_last=( + # the sampler takes care of this (if present) + train_sampler is None + ), + collate_fn=collate_fn_train, + num_workers=num_workers, + ) + ) + train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True) + + # Create dataloader for the validation datasets: + val_dataloaders = [] + for val_dataset, val_sampler in zip(val_datasets, val_samplers, strict=True): + val_dataloaders.append( + DataLoader( + dataset=val_dataset, + batch_size=self.hypers["batch_size"], + sampler=val_sampler, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_val, + num_workers=num_workers, + ) + ) + val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False) + + if is_distributed: + model = DistributedDataParallel(model, device_ids=[device]) + + outputs_list = [] + for target_name, target_info in train_targets.items(): + outputs_list.append(target_name) + for gradient_name in target_info.gradients: + outputs_list.append(f"{target_name}_{gradient_name}_gradients") + + # Create a loss function: + loss_hypers = cast(Dict[str, LossSpecification], self.hypers["loss"]) # mypy + loss_fn = LossAggregator(targets=train_targets, config=loss_hypers) + logging.info("Using the following loss functions:") + for name, info in loss_fn.metadata.items(): + logging.info(f"{name}:") + main = {k: v for k, v in info.items() if k != "gradients"} + logging.info(main) + if "gradients" not in info or len(info["gradients"]) == 0: + continue + logging.info("With gradients:") + for grad, ginfo in info["gradients"].items(): + logging.info(f"\t{name}::{grad}: {ginfo}") + + if self.hypers["weight_decay"] is not None: + optimizer = torch.optim.AdamW( + model.parameters(), + lr=self.hypers["learning_rate"], + weight_decay=self.hypers["weight_decay"], + ) + else: + optimizer = torch.optim.Adam( + model.parameters(), lr=self.hypers["learning_rate"] + ) + + if self.optimizer_state_dict is not None and not is_finetune: + # try to load the optimizer state dict, but this is only possible + # if there are no new targets in the model (new parameters) + if not (model.module if is_distributed else model).has_new_targets: + optimizer.load_state_dict(self.optimizer_state_dict) + + # Create a learning rate scheduler + lr_scheduler = get_scheduler(optimizer, self.hypers, len(train_dataloader)) + + if self.scheduler_state_dict is not None and not is_finetune: + # same as the optimizer, try to load the scheduler state dict + if not (model.module if is_distributed else model).has_new_targets: + lr_scheduler.load_state_dict(self.scheduler_state_dict) + + per_structure_targets = self.hypers["per_structure_targets"] + + # Log the initial learning rate: + logging.info(f"Base learning rate: {self.hypers['learning_rate']}") + + start_epoch = -1 if self.epoch is None else self.epoch + epoch = start_epoch + + # Save the untrained model checkpoint: + if self.best_metric is None: + self.best_metric = float("inf") + logging.info("Saving untrained model checkpoint") + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = 
lr_scheduler.state_dict() + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_untrained.ckpt", + ) + + # Train the model: + logging.info("Starting training") + start_epoch = start_epoch + 1 + + for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]): + if is_distributed: + for train_sampler in train_samplers: + train_sampler.set_epoch(epoch) + train_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + val_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + if self.hypers["log_mae"]: + train_mae_calculator = MAEAccumulator( + self.hypers["log_separate_blocks"] + ) + val_mae_calculator = MAEAccumulator(self.hypers["log_separate_blocks"]) + + train_loss = 0.0 + for batch in train_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + optimizer.zero_grad() + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=True, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + train_loss_batch = loss_fn(predictions, targets, extra_data) + + if is_distributed: + # make sure all parameters contribute to the gradient calculation + # to make torch DDP happy + for param in model.parameters(): + train_loss_batch += 0.0 * param.sum() + + train_loss_batch.backward() + torch.nn.utils.clip_grad_norm_( + model.parameters(), self.hypers["grad_clip_norm"] + ) + optimizer.step() + lr_scheduler.step() + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(train_loss_batch) + train_loss += train_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + train_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + train_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_train_info = train_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_train_info.update( + train_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + val_loss = 0.0 + for batch in val_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=False, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + val_loss_batch = loss_fn(predictions, targets, extra_data) + + if 
is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(val_loss_batch) + val_loss += val_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + val_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + val_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_val_info = val_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_val_info.update( + val_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + # Now we log the information: + finalized_train_info = { + "loss": train_loss, + **finalized_train_info, + } + finalized_val_info = { + "loss": val_loss, + **finalized_val_info, + } + + if epoch == start_epoch: + metric_logger = MetricLogger( + log_obj=ROOT_LOGGER, + dataset_info=( + model.module if is_distributed else model + ).dataset_info, + initial_metrics=[finalized_train_info, finalized_val_info], + names=["training", "validation"], + ) + if epoch % self.hypers["log_interval"] == 0: + metric_logger.log( + metrics=[finalized_train_info, finalized_val_info], + epoch=epoch, + rank=rank, + learning_rate=optimizer.param_groups[0]["lr"], + ) + + val_metric = get_selected_metric( + finalized_val_info, self.hypers["best_model_metric"] + ) + if val_metric < self.best_metric: + self.best_metric = val_metric + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + if epoch % self.hypers["checkpoint_interval"] == 0: + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + self.epoch = epoch + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_{epoch}.ckpt", + ) + + # prepare for the checkpoint that will be saved outside the function + self.epoch = epoch + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + + if is_distributed: + torch.distributed.destroy_process_group() + + def save_checkpoint(self, model: ModelInterface, path: Union[str, Path]) -> None: + checkpoint = model.get_checkpoint() + if self.best_model_state_dict is not None: + self.best_model_state_dict["finetune_config"] = model.finetune_config + checkpoint.update( + { + "trainer_ckpt_version": self.__checkpoint_version__, + "train_hypers": self.hypers, + "epoch": self.epoch, + "optimizer_state_dict": self.optimizer_state_dict, + "scheduler_state_dict": self.scheduler_state_dict, + "best_epoch": self.best_epoch, + "best_metric": self.best_metric, + "best_model_state_dict": self.best_model_state_dict, + "best_optimizer_state_dict": self.best_optimizer_state_dict, + } + ) + torch.save( + checkpoint, + check_file_extension(path, ".ckpt"), + ) + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + hypers: TrainerHypers, + context: Literal["restart", "finetune"], + ) -> "Trainer": + trainer = cls(hypers) + trainer.optimizer_state_dict = 
checkpoint["optimizer_state_dict"] + trainer.scheduler_state_dict = checkpoint["scheduler_state_dict"] + trainer.epoch = checkpoint["epoch"] + trainer.best_epoch = checkpoint["best_epoch"] + trainer.best_metric = checkpoint["best_metric"] + trainer.best_model_state_dict = checkpoint["best_model_state_dict"] + trainer.best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"] + + return trainer + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["trainer_ckpt_version"] == v: + update = getattr(checkpoints, f"trainer_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["trainer_ckpt_version"] = v + 1 + + if checkpoint["trainer_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using " + f"trainer version {checkpoint['trainer_ckpt_version']}, while the " + f"current trainer version is {cls.__checkpoint_version__}." + ) + return checkpoint diff --git a/src/metatrain/pet_9/__init__.py b/src/metatrain/pet_9/__init__.py new file mode 100644 index 0000000000..01db7d8aa9 --- /dev/null +++ b/src/metatrain/pet_9/__init__.py @@ -0,0 +1,21 @@ +from .model import PET +from .trainer import Trainer + + +__model__ = PET +__trainer__ = Trainer +__capabilities__ = { + "supported_devices": __model__.__supported_devices__, + "supported_dtypes": __model__.__supported_dtypes__, +} + +__authors__ = [ + ("Sergey Pozdnyakov ", "@spozdn"), + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] + +__maintainers__ = [ + ("Arslan Mazitov ", "@abmazitov"), + ("Filippo Bigi ", "@frostedoyster"), +] diff --git a/src/metatrain/pet_9/checkpoints.py b/src/metatrain/pet_9/checkpoints.py new file mode 100644 index 0000000000..020b1e6590 --- /dev/null +++ b/src/metatrain/pet_9/checkpoints.py @@ -0,0 +1,422 @@ +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + + +########################### +# MODEL ################### +########################### + + +def model_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "additive_models.0.model.type_to_index" not in state_dict: + state_dict["additive_models.0.model.type_to_index"] = state_dict.pop( + "additive_models.0.type_to_index" + ) + + +def model_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "train_hypers" in checkpoint: + finetune_config = checkpoint["train_hypers"].get("finetune", {}) + else: + finetune_config = {} + state_dict["finetune_config"] = finetune_config + + +def model_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["epoch"] = checkpoint.get("epoch") + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + if checkpoint["best_model_state_dict"] is not None: + checkpoint["best_model_state_dict"] = checkpoint.get("best_model_state_dict") + + +def model_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 checkpoint to v5. + + :param checkpoint: The checkpoint to update. 
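The ``upgrade_checkpoint`` method in the trainer above walks a checkpoint through every intermediate version, applying one ``*_update_vN_vN+1`` function per step. The same pattern in miniature, with a hypothetical ``upgraders`` mapping standing in for this module (the two toy upgraders are loosely modelled on ``trainer_update_v1_v2`` and ``trainer_update_v2_v3`` later in this file):

def upgrade(ckpt, current_version, upgraders):
    while ckpt["version"] < current_version:
        v = ckpt["version"]
        upgraders[(v, v + 1)](ckpt)   # mutate the checkpoint in place
        ckpt["version"] = v + 1
    return ckpt

upgraders = {
    (1, 2): lambda c: c.setdefault("scheduler_factor", 0.5),
    (2, 3): lambda c: c.setdefault("best_epoch", None),
}
print(upgrade({"version": 1}, 3, upgraders))
# {'version': 3, 'scheduler_factor': 0.5, 'best_epoch': None}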
+ """ + checkpoint["model_data"]["dataset_info"]._atomic_types = list( + checkpoint["model_data"]["dataset_info"]._atomic_types + ) + + +def model_update_v5_v6(checkpoint: dict) -> None: + """ + Update a v5 checkpoint to v6. + + :param checkpoint: The checkpoint to update. + """ + if not checkpoint["best_model_state_dict"]: + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + +def model_update_v6_v7(checkpoint: dict) -> None: + """ + Update model checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to be updated. + """ + # this update consists in changes in the scaler + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if ( + "scaler.scales" not in state_dict + and "scaler.dummy_buffer" in state_dict + and "scaler.model.type_to_index" in state_dict + ): + continue # already updated + old_scales_tensor = state_dict.pop("scaler.scales") + old_output_name_to_output_index = {} + for target_index, target_name in enumerate( + checkpoint["model_data"]["dataset_info"].targets.keys() + ): + old_output_name_to_output_index[target_name] = target_index + state_dict["scaler.dummy_buffer"] = torch.tensor( + [0.0], dtype=old_scales_tensor.dtype + ) + state_dict["scaler.model.type_to_index"] = state_dict[ + "additive_models.0.model.type_to_index" + ] + for target_name, target_info in checkpoint["model_data"][ + "dataset_info" + ].targets.items(): + layout = target_info.layout + if layout.sample_names == ["system"]: + samples = Labels(["atomic_type"], torch.tensor([[-1]])) + + elif layout.sample_names == ["system", "atom"]: + samples = Labels( + ["atomic_type"], + torch.arange( + len(checkpoint["model_data"]["dataset_info"].atomic_types) + ).reshape(-1, 1), + ) + else: + raise ValueError( # will never happen + "Unknown sample kind. Please contact the developers." + ) + scales_tensormap = TensorMap( + keys=layout.keys, + blocks=[ + TensorBlock( + values=torch.full( # important when scale_targets=False + (len(samples), len(block.properties)), + old_scales_tensor[ + old_output_name_to_output_index[target_name] + ], + dtype=torch.float64, + ), + samples=samples, + components=[], + properties=block.properties, + ) + for block in layout.blocks() + ], + ) + state_dict[f"scaler.{target_name}_scaler_buffer"] = mts.save_buffer( + mts.make_contiguous(scales_tensormap) + ) + + +def model_update_v7_v8(checkpoint: dict) -> None: + """ + Update a v7 checkpoint to v8. + + :param checkpoint: The checkpoint to update. + """ + if "node_embedding.weight" in checkpoint["model_state_dict"]: + ############################################## + # **Updating the large-scale PET checkpoints** + ############################################## + + checkpoint["model_data"]["model_hypers"]["normalization"] = "RMSNorm" + checkpoint["model_data"]["model_hypers"]["activation"] = "SwiGLU" + checkpoint["model_data"]["model_hypers"]["d_node"] = ( + 4 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PreLN" + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "feedforward" + checkpoint["model_data"]["model_hypers"]["d_feedforward"] = ( + 2 * checkpoint["model_data"]["model_hypers"]["d_pet"] + ) + if (state_dict := checkpoint.get("model_state_dict")) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + if "embedding." 
in k and "node" not in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if "node_embedding." in k: + k = k.replace("node_embedding.", "node_embedders.0.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + if "combination_rmsnorms" in k: + k = k.replace("combination_rmsnorms", "combination_norms") + new_state_dict[k] = v + checkpoint["model_state_dict"] = new_state_dict + + # for the large-scale checkpoints, the model for evaluation is always + # taken to be the last + checkpoint["best_model_state_dict"] = checkpoint["model_state_dict"] + + else: + ############################################### + # **Updating the standard old PET checkpoints** + ############################################### + + # Adding the option for choosing the normalization type + if "normalization" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["normalization"] = "LayerNorm" + # Adding the option for choosing the activation function + if "activation" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["activation"] = "SiLU" + # Setting the node features dimension to be the same as d_pet if not specified + if "d_node" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["d_node"] = checkpoint[ + "model_data" + ]["model_hypers"]["d_pet"] + # Setting the default transformer type to PostLN if not specified + if "transformer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["transformer_type"] = "PostLN" + if "featurizer_type" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["featurizer_type"] = "residual" + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + # Replacing the nn.Sequential MLP with a custom FeedForward module + if ".mlp.0" in k: + k = k.replace(".mlp.0", ".mlp.w_in") + if ".mlp.3" in k: + k = k.replace(".mlp.3", ".mlp.w_out") + # Moving the node embedder to a top-level PET model attribute + if "embedding." in k: + k = k.replace("embedding.", "edge_embedder.") + v = v[:-1, :] # removing the embedding for padding +1 type + if ".node_embedder." in k: + key_content = k.split(".") + k = ".".join( + ["node_embedders", key_content[1], key_content[-1]] + ) + v = v[:-1, :] # removing the embedding for padding +1 type + if ".neighbor_embedder." in k: + v = v[:-1, :] # removing the embedding for padding +1 type + new_state_dict[k] = v + checkpoint[key] = new_state_dict + + +def model_update_v8_v9(checkpoint: dict) -> None: + """ + Update a v8 checkpoint to v9. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + if "finetune_config" in state_dict: + if "inherit_heads" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["inherit_heads"] = {} + if "method" not in state_dict["finetune_config"]: + state_dict["finetune_config"]["method"] = "full" + + +def model_update_v9_v10(checkpoint: dict) -> None: + """ + Update a v9 checkpoint to v10. + + :param checkpoint: The checkpoint to update. 
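The state-dict surgery in ``model_update_v7_v8`` above boils down to two operations: renaming keys and slicing off the trailing "padding type" row of embedding weights. A toy sketch with made-up key names and shapes:

import torch

old_state = {
    "gnn_layers.0.embedding.weight": torch.randn(5, 8),  # 4 species + 1 padding row
    "gnn_layers.0.mlp.0.weight": torch.randn(16, 8),
}
new_state = {}
for k, v in old_state.items():
    if "embedding." in k:
        k = k.replace("embedding.", "edge_embedder.")
        v = v[:-1, :]                     # drop the padding-type embedding row
    if ".mlp.0" in k:
        k = k.replace(".mlp.0", ".mlp.w_in")
    new_state[k] = v

print({k: tuple(v.shape) for k, v in new_state.items()})
# {'gnn_layers.0.edge_embedder.weight': (4, 8), 'gnn_layers.0.mlp.w_in.weight': (16, 8)}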
+ """ + # Adding the num_neighbors_adaptive hyperparameter if not present + if "num_neighbors_adaptive" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["num_neighbors_adaptive"] = None + if "cutoff_function" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["cutoff_function"] = "Cosine" + + +def model_update_v10_v11(checkpoint: dict) -> None: + """ + Update a v10 checkpoint to v11. + + :param checkpoint: The checkpoint to update. + """ + # Adding the attention_temperature hyperparameter if not present + if "attention_temperature" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["attention_temperature"] = 1.0 + + +########################### +# TRAINER ################# +########################### + + +def trainer_update_v1_v2(checkpoint: dict) -> None: + """ + Update a v1 Trainer checkpoint to v2. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["scheduler_factor"] = checkpoint["train_hypers"].get( + "scheduler_factor", 0.5 + ) + + +def trainer_update_v2_v3(checkpoint: dict) -> None: + """ + Update a v2 Trainer checkpoint to v3. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + +def trainer_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 Trainer checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + new_loss_hypers[target_name] = { + "type": old_loss_hypers["type"], + "weight": old_loss_hypers["weights"].get(target_name, 1.0), + "reduction": old_loss_hypers["reduction"], + "sliding_factor": old_loss_hypers.get("sliding_factor", None), + } + checkpoint["train_hypers"]["loss"] = new_loss_hypers + + +def trainer_update_v4_v5(checkpoint: dict) -> None: + """ + Update a v4 Trainer checkpoint to v5. + + :param checkpoint: The checkpoint to update. + """ + raise ValueError( + "In order to use this checkpoint, you need metatrain 2025.10 or earlier. " + "You can install it with `pip install metatrain==2025.10`." + ) + + +def trainer_update_v5_v6(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 5 to version 6. + + :param checkpoint: The checkpoint to update. + """ + # num_workers=0 means that the main process will do the data loading, which is + # equivalent to not setting it (this was the behavior before v6) + checkpoint["train_hypers"]["num_workers"] = 0 + + +def trainer_update_v6_v7(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 6 to version 7. + + :param checkpoint: The checkpoint to update. + """ + checkpoint["train_hypers"]["fixed_scaling_weights"] = {} + + +def trainer_update_v7_v8(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 7 to version 8. + + :param checkpoint: The checkpoint to update. 
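To make the ``trainer_update_v3_v4`` reshaping above concrete: a single flat loss section becomes one entry per target, with the per-target weight looked up from the old ``weights`` mapping and defaulting to 1.0. A sketch with toy values (``energy`` and ``mtt:dos`` are just example target names); the ``sliding_factor`` entries are later dropped again by ``trainer_update_v7_v8``:

old_loss = {"type": "mse", "weights": {"energy": 1.0}, "reduction": "mean"}
target_names = ["energy", "mtt:dos"]   # e.g. the targets in dataset_info

new_loss = {
    name: {
        "type": old_loss["type"],
        "weight": old_loss["weights"].get(name, 1.0),   # falls back to 1.0
        "reduction": old_loss["reduction"],
        "sliding_factor": old_loss.get("sliding_factor", None),
    }
    for name in target_names
}
print(new_loss["mtt:dos"])
# {'type': 'mse', 'weight': 1.0, 'reduction': 'mean', 'sliding_factor': None}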
+    """
+    # remove the `sliding_factor` entries from the loss hypers
+    old_loss_hypers = checkpoint["train_hypers"]["loss"].copy()
+    dataset_info = checkpoint["model_data"]["dataset_info"]
+    new_loss_hypers = {}
+
+    for target_name in dataset_info.targets.keys():
+        # retain everything except sliding_factor for each target
+        new_loss_hypers[target_name] = {
+            k: v
+            for k, v in old_loss_hypers[target_name].items()
+            if k != "sliding_factor"
+        }
+
+    checkpoint["train_hypers"]["loss"] = new_loss_hypers
+
+
+def trainer_update_v8_v9(checkpoint: dict) -> None:
+    """
+    Update trainer checkpoint from version 8 to version 9.
+
+    :param checkpoint: The checkpoint to update.
+    """
+    # Adding the empty finetune config if not present
+    if "finetune" not in checkpoint["train_hypers"]:
+        checkpoint["train_hypers"]["finetune"] = {
+            "read_from": None,
+            "method": "full",
+            "config": {},
+            "inherit_heads": {},
+        }
+
+
+def trainer_update_v9_v10(checkpoint: dict) -> None:
+    """
+    Update trainer checkpoint from version 9 to version 10.
+
+    :param checkpoint: The checkpoint to update.
+    """
+    # Adding the ``remove_composition_contribution`` hyper; True reproduces the
+    # behavior of older checkpoints
+    checkpoint["train_hypers"]["remove_composition_contribution"] = True
+
+
+def trainer_update_v10_v11(checkpoint: dict) -> None:
+    """
+    Update trainer checkpoint from version 10 to version 11.
+
+    :param checkpoint: The checkpoint to update.
+    """
+    # - Remove the ``remove_composition_contribution`` hyper.
+    # - Rename ``fixed_composition_weights`` to ``atomic_baseline``.
+    # - If ``remove_composition_contribution`` is False, set all atomic baselines
+    #   to 0.0 for all targets.
+    use_atomic_baseline = checkpoint["train_hypers"].pop(
+        "remove_composition_contribution"
+    )
+    atomic_baseline = checkpoint["train_hypers"].pop("fixed_composition_weights")
+
+    if not use_atomic_baseline:
+        # Just set all atomic baselines to zero
+        dataset_info = checkpoint["model_data"]["dataset_info"]
+        atomic_baseline = {target_name: 0.0 for target_name in dataset_info.targets}
+
+    checkpoint["train_hypers"]["atomic_baseline"] = atomic_baseline
+
+
+def trainer_update_v11_v12(checkpoint: dict) -> None:
+    """
+    Update trainer checkpoint from version 11 to version 12.
+
+    :param checkpoint: The checkpoint to update.
+    """
+    checkpoint["train_hypers"]["batch_atom_bounds"] = [None, None]
diff --git a/src/metatrain/pet_9/documentation.py b/src/metatrain/pet_9/documentation.py
new file mode 100644
index 0000000000..a563ddb74c
--- /dev/null
+++ b/src/metatrain/pet_9/documentation.py
@@ -0,0 +1,252 @@
+"""
+PET
+===
+
+PET is a cleaner, more user-friendly reimplementation of the original
+PET model :footcite:p:`pozdnyakov_smooth_2023`. It is designed for better
+modularity and maintainability, while preserving compatibility with the original
+PET implementation in ``metatrain``. It also adds new features like long-range
+features, a better fine-tuning implementation, the possibility to train on
+arbitrary targets, and faster inference due to ``fast attention``.
+
+{{SECTION_INSTALLATION}}
+
+{{SECTION_DEFAULT_HYPERS}}
+
+Tuning hyperparameters
+----------------------
+
+The default hyperparameters above will work well in most cases, but they
+may not be optimal for your specific dataset. There is a good number of
+parameters to tune, both for the
+:ref:`model ` and the
+:ref:`trainer `. Since seeing them
+for the first time might be overwhelming, here we provide a **list of the
+parameters that are in general the most important** (in decreasing order
+of importance):
+
+.. container:: mtt-hypers-remove-classname
+
+    .. autoattribute:: {{model_hypers_path}}.cutoff
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.num_neighbors_adaptive
+        :no-index:
+
+    .. autoattribute:: {{trainer_hypers_path}}.learning_rate
+        :no-index:
+
+    .. autoattribute:: {{trainer_hypers_path}}.batch_size
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.d_pet
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.d_node
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.num_gnn_layers
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.num_attention_layers
+        :no-index:
+
+    .. autoattribute:: {{trainer_hypers_path}}.loss
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.long_range
+        :no-index:
+"""
+
+from typing import Literal, Optional
+
+from typing_extensions import TypedDict
+
+from metatrain.pet_9.modules.finetuning import FinetuneHypers, NoFinetuneHypers
+from metatrain.utils.additive import FixedCompositionWeights
+from metatrain.utils.hypers import init_with_defaults
+from metatrain.utils.long_range import LongRangeHypers
+from metatrain.utils.loss import LossSpecification
+from metatrain.utils.scaler import FixedScalerWeights
+
+
+class ModelHypers(TypedDict):
+    """Hyperparameters for the PET model."""
+
+    cutoff: float = 4.5
+    """Cutoff radius for neighbor search.
+
+    This should be set to a value after which most of the interactions
+    between atoms are expected to be negligible. A lower cutoff will lead
+    to faster models.
+    """
+    num_neighbors_adaptive: Optional[int] = None
+    """Target number of neighbors for the adaptive cutoff scheme.
+
+    This parameter activates the adaptive cutoff functionality.
+    Each atomic environment has a different cutoff, chosen
+    such that the number of neighbors is approximately equal to this
+    value. This can be useful to have a more uniform number of neighbors
+    per atom, especially in sparse systems. Setting it to None disables
+    this feature and uses all neighbors within the fixed cutoff radius.
+    """
+    cutoff_function: Literal["Cosine", "Bump"] = "Bump"
+    """Type of the smoothing function at the cutoff."""
+    cutoff_width: float = 0.5
+    """Width of the smoothing function at the cutoff."""
+    d_pet: int = 128
+    """Dimension of the edge features.
+
+    This hyperparameter controls the width of the neural network. In general,
+    increasing it might lead to better accuracy, especially on larger datasets, at the
+    cost of increased training and evaluation time.
+    """
+    d_head: int = 128
+    """Dimension of the attention heads."""
+    d_node: int = 256
+    """Dimension of the node features.
+
+    Increasing this hyperparameter might lead to better accuracy,
+    with a relatively small increase in inference time.
+    """
+    d_feedforward: int = 256
+    """Dimension of the feedforward network in the attention layer."""
+    num_heads: int = 8
+    """Number of attention heads per attention layer."""
+    num_attention_layers: int = 2
+    """The number of attention layers in each layer of the graph
+    neural network. Depending on the dataset, increasing this hyperparameter might
+    lead to better accuracy, at the cost of increased training and evaluation time.
+    """
+    num_gnn_layers: int = 2
+    """The number of graph neural network layers.
+
+    In general, decreasing this hyperparameter to 1 will lead to much faster models,
+    at the expense of accuracy. Increasing it may or may not lead to better accuracy,
+    depending on the dataset, at the cost of increased training and evaluation time.
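A rough standalone illustration of the cost note in the ``d_pet`` docstring above: the attention projections scale quadratically with the feature width, so doubling ``d_pet`` roughly quadruples those weights. The 3x expansion mimics a fused query/key/value projection like the one in ``AttentionBlock``; the exact layer shapes in the model may differ:

import torch.nn as nn

def n_params(d):
    # one fused q/k/v projection (d -> 3d) plus its output projection (d -> d)
    return sum(p.numel() for p in [nn.Linear(d, 3 * d).weight, nn.Linear(d, d).weight])

print(n_params(128), n_params(256), n_params(256) / n_params(128))  # 65536 262144 4.0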
+ """ + normalization: Literal["RMSNorm", "LayerNorm"] = "RMSNorm" + """Layer normalization type.""" + activation: Literal["SiLU", "SwiGLU"] = "SwiGLU" + """Activation function.""" + attention_temperature: float = 1.0 + """The temperature scaling factor for attention scores.""" + transformer_type: Literal["PreLN", "PostLN"] = "PreLN" + """The order in which the layer normalization and attention + are applied in a transformer block. Available options are ``PreLN`` + (normalization before attention) and ``PostLN`` (normalization after attention).""" + featurizer_type: Literal["residual", "feedforward"] = "feedforward" + """Implementation of the featurizer of the model to use. Available + options are ``residual`` (the original featurizer from the PET paper, that uses + residual connections at each GNN layer for readout) and ``feedforward`` (a modern + version that uses the last representation after all GNN iterations for readout). + Additionally, the feedforward version uses bidirectional features flow during the + message passing iterations, that favors features flowing from atom ``i`` to atom + ``j`` to be not equal to the features flowing from atom ``j`` to atom ``i``.""" + zbl: bool = False + """Use ZBL potential for short-range repulsion""" + long_range: LongRangeHypers = init_with_defaults(LongRangeHypers) + """Long-range Coulomb interactions parameters.""" + num_experts: int = 1 + + +class TrainerHypers(TypedDict): + """Hyperparameters for training PET models.""" + + distributed: bool = False + """Whether to use distributed training""" + distributed_port: int = 39591 + """Port for distributed communication among processes""" + batch_size: int = 16 + """The number of samples to use in each batch of training. This + hyperparameter controls the tradeoff between training speed and memory usage. In + general, larger batch sizes will lead to faster training, but might require more + memory.""" + num_epochs: int = 1000 + """Number of epochs.""" + warmup_fraction: float = 0.01 + """Fraction of training steps used for learning rate warmup.""" + learning_rate: float = 1e-4 + """Learning rate.""" + weight_decay: Optional[float] = None + + log_interval: int = 1 + """Interval to log metrics.""" + checkpoint_interval: int = 100 + """Interval to save checkpoints.""" + atomic_baseline: FixedCompositionWeights = {} + """The baselines for each target. + + By default, ``metatrain`` will fit a linear model (:class:`CompositionModel + `) to compute the + least squares baseline for each atomic species for each target. + + However, this hyperparameter allows you to provide your own baselines. + The value of the hyperparameter should be a dictionary where the keys are the + target names, and the values are either (1) a single baseline to be used for + all atomic types, or (2) a dictionary mapping atomic types to their baselines. + For example: + + - ``atomic_baseline: {"energy": {1: -0.5, 6: -10.0}}`` will fix the energy + baseline for hydrogen (Z=1) to -0.5 and for carbon (Z=6) to -10.0, while + fitting the baselines for the energy of all other atomic types, as well + as fitting the baselines for all other targets. + - ``atomic_baseline: {"energy": -5.0}`` will fix the energy baseline for + all atomic types to -5.0. + - ``atomic_baseline: {"mtt:dos": 0.0}`` sets the baseline for the "mtt:dos" + target to 0.0, effectively disabling the atomic baseline for that target. 
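A worked example with the numbers from the first bullet above: for a structure containing one carbon and two hydrogens, the fixed baseline contributes a composition-weighted sum in the units of the target:

atomic_baseline = {1: -0.5, 6: -10.0}   # per-atom baselines, by atomic number
atomic_numbers = [6, 1, 1]              # one structure: C, H, H

baseline_total = sum(atomic_baseline[z] for z in atomic_numbers)
print(baseline_total)  # -11.0: each per-atom value times the count of that species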
+ + This atomic baseline is substracted from the targets during training, which + avoids the main model needing to learn atomic contributions, and likely makes + training easier. When the model is used in evaluation mode, the atomic baseline + is added on top of the model predictions automatically. + + .. note:: + + This atomic baseline is a per-atom contribution. Therefore, if the property + you are predicting is a sum over all atoms (e.g., total energy), the + contribution of the atomic baseline to the total property will be the + atomic baseline multiplied by the number of atoms of that type in the + structure. + """ + scale_targets: bool = True + """Normalize targets to unit std during training.""" + fixed_scaling_weights: FixedScalerWeights = {} + """Weights for target scaling. + + This is passed to the ``fixed_weights`` argument of + :meth:`Scaler.train_model `, + see its documentation to understand exactly what to pass here. + """ + per_structure_targets: list[str] = [] + """Targets to calculate per-structure losses.""" + num_workers: Optional[int] = None + """Number of workers for data loading. If not provided, it is set + automatically.""" + log_mae: bool = True + """Log MAE alongside RMSE""" + log_separate_blocks: bool = False + """Log per-block error.""" + best_model_metric: Literal["rmse_prod", "mae_prod", "loss"] = "mae_prod" + """Metric used to select best checkpoint (e.g., ``rmse_prod``)""" + grad_clip_norm: float = 1.0 + """Maximum gradient norm value.""" + loss: str | dict[str, LossSpecification | str] = "mse" + """This section describes the loss function to be used. See the + :ref:`loss-functions` for more details.""" + batch_atom_bounds: list[Optional[int]] = [None, None] + """Bounds for the number of atoms per batch as [min, max]. Batches with atom + counts outside these bounds will be skipped during training. Use ``None`` for + either value to disable that bound. This is useful for preventing out-of-memory + errors and ensuring consistent computational load. Default: ``[None, None]``.""" + + finetune: NoFinetuneHypers | FinetuneHypers = { + "read_from": None, + "method": "full", + "config": {}, + "inherit_heads": {}, + } + """Parameters for fine-tuning trained PET models. + + See :ref:`label_fine_tuning_concept` for more details. + """ diff --git a/src/metatrain/pet_9/model.py b/src/metatrain/pet_9/model.py new file mode 100644 index 0000000000..e3370ec5a6 --- /dev/null +++ b/src/metatrain/pet_9/model.py @@ -0,0 +1,1659 @@ +import logging +import typing +import warnings +from math import prod +from typing import Any, Dict, List, Literal, Optional, Tuple + +import metatensor.torch as mts +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.operations._add import _add_block_block +from metatomic.torch import ( + AtomisticModel, + ModelCapabilities, + ModelMetadata, + ModelOutput, + NeighborListOptions, + System, +) +from torch.utils.hooks import RemovableHandle + +from metatrain.utils.abc import ModelInterface +from metatrain.utils.additive import ZBL, CompositionModel +from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.dtype import dtype_to_str +from metatrain.utils.long_range import DummyLongRangeFeaturizer, LongRangeFeaturizer +from metatrain.utils.metadata import merge_metadata +from metatrain.utils.scaler import Scaler +from metatrain.utils.sum_over_atoms import sum_over_atoms + +from . 
import checkpoints +from .documentation import ModelHypers +from .modules.finetuning import apply_finetuning_strategy +from .modules.readout import ReadoutLayer +from .modules.structures import get_pair_sample_labels, systems_to_batch +from .modules.transformer import CartesianTransformer + + +AVAILABLE_FEATURIZERS = typing.get_args(ModelHypers.__annotations__["featurizer_type"]) + + +class PET(ModelInterface[ModelHypers]): + """ + Metatrain-native implementation of the PET architecture. + + Originally proposed in work (https://arxiv.org/abs/2305.19302v3), + and published in the `pet` package (https://github.com/spozdn/pet). + + :param hypers: Hyperparameters for the PET model. See the documentation for details. + :param dataset_info: Information about the dataset, including atomic types and + targets. + """ + + __checkpoint_version__ = 11 + __supported_devices__ = ["cuda", "cpu"] + __supported_dtypes__ = [torch.float32, torch.float64] + __default_metadata__ = ModelMetadata( + references={"architecture": ["https://arxiv.org/abs/2305.19302v3"]} + ) + component_labels: Dict[str, List[List[Labels]]] + NUM_FEATURE_TYPES: int = 2 # node + edge features + + def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None: + super().__init__(hypers, dataset_info, self.__default_metadata__) + + # Cache frequently accessed hyperparameters + self.cutoff = float(self.hypers["cutoff"]) + self.cutoff_function = self.hypers["cutoff_function"] + self.cutoff_width = float(self.hypers["cutoff_width"]) + self.num_neighbors_adaptive = ( + float(self.hypers["num_neighbors_adaptive"]) + if self.hypers["num_neighbors_adaptive"] is not None + else None + ) + self.d_pet = self.hypers["d_pet"] + self.d_node = self.hypers["d_node"] + self.d_head = self.hypers["d_head"] + self.d_feedforward = self.hypers["d_feedforward"] + self.num_heads = self.hypers["num_heads"] + self.num_gnn_layers = self.hypers["num_gnn_layers"] + self.num_attention_layers = self.hypers["num_attention_layers"] + self.num_experts = self.hypers["num_experts"] + self.normalization = self.hypers["normalization"] + self.activation = self.hypers["activation"] + self.attention_temperature = self.hypers["attention_temperature"] + self.transformer_type = self.hypers["transformer_type"] + self.featurizer_type = self.hypers["featurizer_type"] + + self.atomic_types = dataset_info.atomic_types + self.requested_nl = NeighborListOptions( + cutoff=self.cutoff, + full_list=True, + strict=True, + ) + num_atomic_species = len(self.atomic_types) + self.gnn_layers = torch.nn.ModuleList( + [ + CartesianTransformer( + self.cutoff, + self.cutoff_width, + self.d_pet, + self.num_heads, + self.d_node, + self.d_feedforward, + self.num_attention_layers, + self.normalization, + self.activation, + self.attention_temperature, + self.transformer_type, + num_atomic_species, + layer_index == 0, # is first layer + ) + for layer_index in range(self.num_gnn_layers) + ] + ) + if self.featurizer_type not in AVAILABLE_FEATURIZERS: + raise ValueError( + f"Unknown featurizer type: {self.featurizer_type}. 
" + f"Available options are: {AVAILABLE_FEATURIZERS}" + ) + if self.featurizer_type == "feedforward": + self.num_readout_layers = 1 + self.combination_norms = torch.nn.ModuleList( + [torch.nn.LayerNorm(2 * self.d_pet) for _ in range(self.num_gnn_layers)] + ) + self.combination_mlps = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(2 * self.d_pet, 2 * self.d_pet), + torch.nn.SiLU(), + torch.nn.Linear(2 * self.d_pet, self.d_pet), + ) + for _ in range(self.num_gnn_layers) + ] + ) + else: + self.num_readout_layers = self.num_gnn_layers + self.combination_norms = torch.nn.ModuleList() + self.combination_mlps = torch.nn.ModuleList() + + self.node_embedders = torch.nn.ModuleList( + [ + torch.nn.Embedding(num_atomic_species, self.d_node) + for _ in range(self.num_readout_layers) + ] + ) + self.edge_embedder = torch.nn.Embedding(num_atomic_species, self.d_pet) + + self.node_heads = torch.nn.ModuleDict() + self.edge_heads = torch.nn.ModuleDict() + self.node_last_layers = torch.nn.ModuleDict() + self.edge_last_layers = torch.nn.ModuleDict() + self.last_layer_feature_size = ( + self.num_readout_layers * self.d_head * self.NUM_FEATURE_TYPES + ) # for LLPR + + # the model is always capable of outputting the internal features + self.outputs = { + "features": ModelOutput(per_atom=True, description="internal features") + } + + self.output_shapes: Dict[str, Dict[str, List[int]]] = {} + self.key_labels: Dict[str, Labels] = {} + self.property_labels: Dict[str, List[Labels]] = {} + self.component_labels: Dict[str, List[List[Labels]]] = {} + self.target_names: List[str] = [] + self.last_layer_parameter_names: Dict[str, List[str]] = {} # for LLPR + for target_name, target_info in dataset_info.targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target_info) + + self.register_buffer( + "species_to_species_index", + torch.full((max(self.atomic_types) + 1,), -1), + ) + for i, species in enumerate(self.atomic_types): + self.species_to_species_index[species] = i + + # long-range module + if self.hypers["long_range"]["enable"]: + self.long_range = True + if not self.hypers["long_range"]["use_ewald"]: + warnings.warn( + "Training PET with the LongRangeFeaturizer initialized " + "with `use_ewald=False` causes instabilities during training. " + "The `use_ewald` variable will be force-switched to `True`. 
" + "during training.", + UserWarning, + stacklevel=2, + ) + self.long_range_featurizer = LongRangeFeaturizer( + hypers=self.hypers["long_range"], + feature_dim=self.d_node, + neighbor_list_options=self.requested_nl, + ) + else: + self.long_range = False + self.long_range_featurizer = DummyLongRangeFeaturizer() # for torchscript + + # additive models: these are handled by the trainer at training + # time, and they are added to the output at evaluation time + composition_model = CompositionModel( + hypers={}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + additive_models = [composition_model] + + # Adds the ZBL repulsion model if requested + if self.hypers["zbl"]: + additive_models.append( + ZBL( + {}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if ZBL.is_valid_target(target_name, target_info) + }, + ), + ) + ) + self.additive_models = torch.nn.ModuleList(additive_models) + + # scaler: this is also handled by the trainer at training time + self.scaler = Scaler(hypers={}, dataset_info=dataset_info) + + self.single_label = Labels.single() + + self.finetune_config: Dict[str, Any] = {} + + def supported_outputs(self) -> Dict[str, ModelOutput]: + return self.outputs + + def restart(self, dataset_info: DatasetInfo) -> "PET": + # merge old and new dataset info + merged_info = self.dataset_info.union(dataset_info) + new_atomic_types = [ + at for at in merged_info.atomic_types if at not in self.atomic_types + ] + new_targets = { + key: value + for key, value in merged_info.targets.items() + if key not in self.dataset_info.targets + } + self.has_new_targets = len(new_targets) > 0 + + if len(new_atomic_types) > 0: + raise ValueError( + f"New atomic types found in the dataset: {new_atomic_types}. " + "The PET model does not support adding new atomic types." + ) + + # register new outputs as new last layers + for target_name, target in new_targets.items(): + self.target_names.append(target_name) + self._add_output(target_name, target) + + self.dataset_info = merged_info + + # restart the composition and scaler models + self.additive_models[0] = self.additive_models[0].restart( + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + self.scaler = self.scaler.restart(dataset_info) + + return self + + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return [self.requested_nl] + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + """ + Forward pass of the PET model. + + The forward pass processes atomic systems through multiple stages to produce + predictions for the requested outputs. The computation follows a graph neural + network architecture with attention-based message passing. 
+ + **Stage 0: Input Preparation** + + The input systems are first converted into a batched representation containing: + + - `element_indices_nodes` [n_atoms]: Atomic species of the central atoms + - `element_indices_neighbors` [n_atoms, max_num_neighbors]: Atomic species of + neighboring atoms + - `edge_vectors` [n_atoms, max_num_neighbors, 3]: Cartesian edge vectors + between central atoms and their neighbors + - `padding_mask` [n_atoms, max_num_neighbors]: Mask indicating real vs padded + neighbors + - `reverse_neighbor_index` [n_atoms * max_num_neighbors]: Index of the ji edge + for each ij edge, once the edges are flattened into an array whose first + dimension is n_atoms * max_num_neighbors + - `system_indices` [n_atoms]: System index for each central atom + - `sample_labels` [n_atoms, 2]: Metatensor Labels containing indices of each + atom in each system + + **Stage 1: Feature Computation via GNN Layers** + + Node and edge representations are computed by iterating through the GNN layers + following one of two featurization strategies: + + - **Feedforward featurization**: Features are propagated through all + `num_gnn_layers` GNN layers sequentially, using only the final layer outputs + for readout. At each layer, forward and reversed edge messages are combined + using combination MLPs to enable bidirectional information flow. + + - **Residual featurization**: Intermediate node and edge features from each + GNN layer are saved and used during readout. Edge messages between layers are + averaged to maintain information from all hops. + + During this stage, the model: + + - Embeds atomic species into learned node and edge representations + - Applies Cartesian transformer layers to update features via attention + - Uses reversed neighbor lists to enable bidirectional message passing, where + the new input message from atom `j` to atom `i` in GNN layer N+1 is the + reversed message from atom `i` to atom `j` in GNN layer N + - Applies cutoff functions to weight interactions by distance + + If the long-range module is enabled, electrostatic features computed via Ewald + summation (during training) or Particle-Particle Particle Mesh Ewald (P3M) + (during evaluation) are added to the node features from each GNN layer. + + **Stage 2: Intermediate Feature Output (Optional)** + + If "features" is requested in the outputs, node and edge features from all + layers are concatenated to produce intermediate representations. Edge features + are summed over neighbors with cutoff weighting to obtain per-node + contributions. This output can be used for transfer learning or analysis. + + **Stage 3: Last Layer Feature Computation** + + For each requested output, output-specific heads (shallow MLPs with two linear + layers and SiLU activations) are applied to both node and edge features from + each GNN layer. This produces last layer features that are specialized for each + prediction target. These features can be optionally returned as auxiliary + outputs (e.g., "mtt::aux::energy_last_layer_features") for analysis or + transfer learning. 
+ + **Stage 4: Atomic Predictions** + + Final linear layers are applied to the last layer features to produce per-atom + predictions for each requested output: + + - Node and edge last layer features are processed through separate linear + layers for each output block + - Contributions from all GNN layers are summed + - Edge contributions are summed over neighbors with cutoff weighting + - For rank-2 Cartesian tensors (e.g., stress), predictions are symmetrized and + normalized by cell volume + - Multiple tensor blocks per output are handled independently + + **Post-processing (Evaluation Only)** + + During evaluation (not training), the following transformations are applied: + + 1. **Scaling**: Predictions are scaled using learned or configured scale + factors + 2. **Additive contributions**: Composition model and optional ZBL repulsion + contributions are added to the predictions + + :param systems: List of `metatomic.torch.System` objects to process. Each + system should contain atomic positions, species, and cell information, with + neighbor lists computed according to `requested_neighbor_lists()`. + :param outputs: Dictionary of requested outputs in the format + {output_name: ModelOutput(...)}. The model supports: + + - Target properties (energy, forces, stress, etc.) + - "features": intermediate representations from Stage 2 + - Auxiliary last layer features (e.g., + "mtt::aux::energy_last_layer_features") + + :param selected_atoms: Optional `metatensor.torch.Labels` object specifying a + subset of atoms for which to compute outputs. If `None`, all atoms are + included. This is useful for computing properties for specific atomic + environments. + :return: Dictionary of `metatensor.torch.TensorMap` objects containing the + requested outputs. Each TensorMap contains per-atom or per-structure + predictions (depending on the ModelOutput configuration) with appropriate + metatensor metadata (samples, components, properties). 
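+
+        A minimal usage sketch (assuming ``model`` is a ``PET`` instance, the
+        systems already carry the neighbor lists requested by
+        ``requested_neighbor_lists()``, and a target named ``"energy"`` was
+        present in the dataset; adapt the names to your own targets):
+
+        .. code-block:: python
+
+            from metatomic.torch import ModelOutput
+
+            requested = {
+                # per-structure energy predictions
+                "energy": ModelOutput(per_atom=False),
+                # per-atom internal features (always available)
+                "features": ModelOutput(per_atom=True),
+            }
+            predictions = model(systems, requested)
+            energies = predictions["energy"].block().values
+            per_atom_features = predictions["features"].block().values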
+ """ + device = systems[0].device + return_dict: Dict[str, TensorMap] = {} + nl_options = self.requested_neighbor_lists()[0] + + if self.single_label.values.device != device: + self._move_labels_to_device(device) + + # **Stage 0: Input Preparation** + ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + padding_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + sample_labels, + centers, + nef_to_edges_neighbor, + ) = systems_to_batch( + systems, + nl_options, + self.atomic_types, + self.species_to_species_index, + self.cutoff_function, + self.cutoff_width, + self.num_neighbors_adaptive, + selected_atoms, + ) + + pair_sample_labels = get_pair_sample_labels( + systems, sample_labels, nl_options, device + ) + + # Optional diagnostic token capture: register temporary module hooks + diagnostic_handles = torch.jit.annotate(List[Any], []) + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + diagnostic_handles = self._prepare_diagnostic_handles( + outputs, + return_dict, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + + # the scaled_dot_product_attention function from torch cannot do + # double backward, so we will use manual attention if needed + use_manual_attention = edge_vectors.requires_grad and self.training + + with torch.profiler.record_function("PET::_calculate_features"): + # **Stage 1: Feature Computation via GNN Layers** + featurizer_inputs: Dict[str, torch.Tensor] = dict( + element_indices_nodes=element_indices_nodes, + element_indices_neighbors=element_indices_neighbors, + edge_vectors=edge_vectors, + edge_distances=edge_distances, + reverse_neighbor_index=reverse_neighbor_index, + padding_mask=padding_mask, + cutoff_factors=cutoff_factors, + ) + for featurizer_input_name, tensor in featurizer_inputs.items(): + if "mtt::features::" + featurizer_input_name in outputs: + return_dict["mtt::features::" + featurizer_input_name] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + node_features_list, edge_features_list = self._calculate_features( + featurizer_inputs, + use_manual_attention=use_manual_attention, + ) + + # If the long-range module is activated, we add the long-range features + # on top of the node features + + if self.long_range: + long_range_features = self._calculate_long_range_features( + systems, node_features_list, edge_distances, padding_mask + ) + for i in range(self.num_readout_layers): + node_features_list[i] = ( + node_features_list[i] + long_range_features + ) * 0.5**0.5 + + # **Stage 2: Intermediate Feature Output (Optional)** + with torch.profiler.record_function("PET::_get_output_features"): + if "features" in outputs: + features_dict = self._get_output_features( + node_features_list, + edge_features_list, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + # Since return_dict.update(features_dict) is not Torch-Scriptable, + # we use a simple iteration over the features_dict items. 
+ for k, v in features_dict.items(): + return_dict[k] = v + + # **Stage 3: Last Layer Feature Computation** + with torch.profiler.record_function("PET::_calculate_last_layer_features"): + node_last_layer_features_dict, edge_last_layer_features_dict = ( + self._calculate_last_layer_features( + node_features_list, + edge_features_list, + ) + ) + last_layer_features_dict = self._get_output_last_layer_features( + node_last_layer_features_dict, + edge_last_layer_features_dict, + cutoff_factors, + selected_atoms, + sample_labels, + outputs, + ) + + for k, v in last_layer_features_dict.items(): + return_dict[k] = v + + # **Stage 4: Atomic Predictions** + with torch.profiler.record_function("PET::_calculate_atomic_predictions"): + node_atomic_predictions_dict, edge_atomic_predictions_dict = ( + self._calculate_atomic_predictions( + systems, + node_last_layer_features_dict, + edge_last_layer_features_dict, + padding_mask, + cutoff_factors, + outputs, + ) + ) + atomic_predictions_dict = self._get_output_atomic_predictions( + systems, + node_atomic_predictions_dict, + edge_atomic_predictions_dict, + edge_vectors, + system_indices, + sample_labels, + outputs, + selected_atoms, + ) + + for k, v in atomic_predictions_dict.items(): + return_dict[k] = v + + # **Post-processing (Evaluation Only)** + + with torch.profiler.record_function("PET::post-processing"): + if not self.training: + # at evaluation, we also introduce the scaler and additive contributions + return_dict = self.scaler( + systems, return_dict, selected_atoms=selected_atoms + ) + for additive_model in self.additive_models: + outputs_for_additive_model: Dict[str, ModelOutput] = {} + for name, output in outputs.items(): + if name in additive_model.outputs: + outputs_for_additive_model[name] = output + additive_contributions = additive_model( + systems, + outputs_for_additive_model, + selected_atoms, + ) + for name in additive_contributions: + # TODO: uncomment this after metatensor.torch.add + # is updated to handle sparse sums + # return_dict[name] = metatensor.torch.add( + # return_dict[name], + # additive_contributions[name].to( + # device=return_dict[name].device, + # dtype=return_dict[name].dtype + # ), + # ) + # TODO: "manual" sparse sum: update to metatensor.torch.add + # after sparse sum is implemented in metatensor.operations + output_blocks: List[TensorBlock] = [] + for k, b in return_dict[name].items(): + if k in additive_contributions[name].keys: + output_blocks.append( + _add_block_block( + b, + additive_contributions[name] + .block(k) + .to(device=b.device, dtype=b.dtype), + ) + ) + else: + output_blocks.append(b) + return_dict[name] = TensorMap( + return_dict[name].keys, output_blocks + ) + + # remove any diagnostic hooks we registered and attach tokens + if (not torch.jit.is_scripting()) and (not torch.jit.is_tracing()): + for h in diagnostic_handles: + h.remove() + + return return_dict + + def _create_diagnostic_feature_tensormap( + self, + tensor: torch.Tensor, + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> TensorMap: + assert tensor.shape[0] == sample_labels.values.shape[0], ( + "diagnostic feature tensor must be per-atom or per-pair like in shape." 
+ f" Got tensor.shape = {tensor.shape}, " + ) + + device = tensor.device + + outp = tensor.detach().clone() + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + if outp.shape[1] == 1: # node-like, shape (n_atoms, 1, d) + outp = outp.squeeze(1) + labels = sample_labels + + else: # edge-like, shape (n_atoms, num_neighbors, d) + outp = outp[centers, nef_to_edges_neighbor] + labels = pair_sample_labels + + if outp.ndim == 1: # can happen if d == 1 + outp = outp.unsqueeze(1) + + return TensorMap( + Labels(["_"], torch.tensor([[0]])).to(device=device), + [ + TensorBlock( + values=outp, + samples=labels.to(device=device), + components=[], + properties=Labels( + ["_"], + torch.arange(outp.shape[1]).reshape(-1, 1), + ).to(device=device), + ) + ], + ) + + @torch.jit.ignore + def _prepare_diagnostic_handles( + self, + outputs: Dict[str, ModelOutput], + return_dict: Dict[str, Any], + centers: torch.Tensor, + nef_to_edges_neighbor: torch.Tensor, + sample_labels: Labels, + pair_sample_labels: Labels, + ) -> List[Any]: + """ + Prepare forward hooks to capture diagnostic tokens from internal modules. + + :param outputs: Dictionary of requested outputs. + :param return_dict: Dictionary to store captured tokens. + :param centers: Tensor mapping center atoms to their indices. + :param nef_to_edges_neighbor: Tensor mapping neighbor edges to their indices. + :param sample_labels: Labels for individual atoms. + :param pair_sample_labels: Labels for atom pairs. + + :return: List of removable handles for the registered hooks. + """ + diagnostic_handles = [] + + def _resolve_module(path: str) -> Any: + obj: Any = self + for part in path.split("."): + if part.isdigit(): + obj = obj[int(part)] + else: + if not hasattr(obj, part): + raise AttributeError( + f"Module path '{path}' not found at '{part}'" + ) + obj = getattr(obj, part) + return obj + + # Build list of possible module paths that can be captured for these model + # hypers. 
+ possible_capture_paths: List[str] = [] + + for i in range(self.num_readout_layers): + possible_capture_paths.append(f"node_embedders.{i}") + possible_capture_paths.append("edge_embedder") + + for i in range(self.num_gnn_layers): + # Total GNN layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}_node") + possible_capture_paths.append(f"gnn_layers.{i}_edge") + + # Finer-grained: embeddings and compressions + possible_capture_paths.append(f"gnn_layers.{i}.edge_embedder") + if i > 0: + possible_capture_paths.append(f"gnn_layers.{i}.neighbor_embedder") + possible_capture_paths.append(f"gnn_layers.{i}.compress") + + # transformer layers + for j in range(self.num_attention_layers): + # Transformer layer: special case as returns both node and edge features + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_node") + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}_edge") + + # Finer-grained: attention and MLP submodules + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.attention" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_contraction" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.center_mlp" + ) + possible_capture_paths.append( + f"gnn_layers.{i}.trans.layers.{j}.norm_mlp" + ) + possible_capture_paths.append(f"gnn_layers.{i}.trans.layers.{j}.mlp") + + # TODO: this is subtle, depending on whether the featurizer is feedforward or + # residual. To fix later. + # possible_capture_paths.append(f"combination_norms") + # possible_capture_paths.append(f"combination_mlps") + + for path in possible_capture_paths: + if "mtt::features::" + path not in outputs: + continue + + if "_node" in path: + suffix = "_node" + path = path.replace(suffix, "") + elif "_edge" in path: + suffix = "_edge" + path = path.replace(suffix, "") + else: + suffix = "" + + module = _resolve_module(path) + + def make_hook(p: str, suffix: str) -> Any: + def _hook(module: torch.nn.Module, inp: Any, outp: Any) -> None: + if isinstance(outp, tuple): + assert "_node" in suffix or "_edge" in suffix, ( + "When capturing from a module that returns multiple " + "outputs, the requested output must carry the suffix " + "'_node' or '_edge'." + ) + if suffix == "_node": + tensor = outp[0] + else: + assert suffix == "_edge" + tensor = outp[1] + + return_dict[f"mtt::features::{p}{suffix}"] = ( + self._create_diagnostic_feature_tensormap( + tensor, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + else: + return_dict["mtt::features::" + p] = ( + self._create_diagnostic_feature_tensormap( + outp, + centers, + nef_to_edges_neighbor, + sample_labels, + pair_sample_labels, + ) + ) + + return _hook + + handle = module.register_forward_hook(make_hook(path, suffix)) + diagnostic_handles.append(handle) + + return diagnostic_handles + + def _calculate_features( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Calculate node and edge features using the selected featurization strategy. + Returns lists of feature tensors from GNN layers. 
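+        Dispatches to ``_feedforward_featurization_impl`` or
+        ``_residual_featurization_impl`` depending on ``featurizer_type``.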
+ + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors + - List of edge feature tensors + In the case of feedforward featurization, each list contains a single tensor + from the final GNN layer. In the case of residual featurization, each list + contains tensors from all GNN layers. + """ + if self.featurizer_type == "feedforward": + return self._feedforward_featurization_impl(inputs, use_manual_attention) + else: + return self._residual_featurization_impl(inputs, use_manual_attention) + + def _feedforward_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Feedforward featurization: iterates features through all GNN layers, + returning only the final layer outputs. Uses combination MLPs to mix + forward and reversed edge messages at each layer. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from the final GNN layer + - List of edge feature tensors from the final GNN layer + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + + input_node_embeddings = self.node_embedders[0](inputs["element_indices_nodes"]) + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for combination_norm, combination_mlp, gnn_layer in zip( + self.combination_norms, self.combination_mlps, self.gnn_layers, strict=True + ): + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + input_node_embeddings = output_node_embeddings + new_input_edge_embeddings = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + # input_messages = 0.5 * (output_edge_embeddings + new_input_messages) + concatenated = torch.cat( + [output_edge_embeddings, new_input_edge_embeddings], dim=-1 + ) + input_edge_embeddings = ( + input_edge_embeddings + + output_edge_embeddings + + combination_mlp(combination_norm(concatenated)) + ) + + node_features_list.append(input_node_embeddings) + edge_features_list.append(input_edge_embeddings) + return node_features_list, edge_features_list + + def _residual_featurization_impl( + self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Residual featurization: saves intermediate features from each GNN layer + for use in readout. 
Averages forward and reversed edge messages between layers. + + :param inputs: Dictionary containing input tensors required for feature + computation + :param use_manual_attention: Whether to use manual attention computation + (required for double backward when edge vectors require gradients) + :return: Tuple of two lists: + - List of node feature tensors from all GNN layers + - List of edge feature tensors from all GNN layers + """ + node_features_list: List[torch.Tensor] = [] + edge_features_list: List[torch.Tensor] = [] + input_edge_embeddings = self.edge_embedder(inputs["element_indices_neighbors"]) + for node_embedder, gnn_layer in zip( + self.node_embedders, self.gnn_layers, strict=True + ): + input_node_embeddings = node_embedder(inputs["element_indices_nodes"]) + output_node_embeddings, output_edge_embeddings = gnn_layer( + input_node_embeddings, + input_edge_embeddings, + inputs["element_indices_neighbors"], + inputs["edge_vectors"], + inputs["padding_mask"], + inputs["edge_distances"], + inputs["cutoff_factors"], + use_manual_attention, + ) + node_features_list.append(output_node_embeddings) + edge_features_list.append(output_edge_embeddings) + + # The GNN contraction happens by reordering the messages, + # using a reversed neighbor list, so the new input message + # from atom `j` to atom `i` in on the GNN layer N+1 is a + # reversed message from atom `i` to atom `j` on the GNN layer N. + # (Flatten, index, and reshape to the original shape) + new_input_messages = output_edge_embeddings.reshape( + output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + )[inputs["reverse_neighbor_index"]].reshape( + output_edge_embeddings.shape[0], + output_edge_embeddings.shape[1], + output_edge_embeddings.shape[2], + ) + input_edge_embeddings = 0.5 * (input_edge_embeddings + new_input_messages) + return node_features_list, edge_features_list + + def _calculate_long_range_features( + self, + systems: List[System], + node_features_list: List[torch.Tensor], + edge_distances: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + """ + Calculate long-range electrostatic features using Ewald summation. + Forces use_ewald=True during training for stability. + + :param systems: List of `metatomic.torch.System` objects to process. + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_distances: Tensor of edge distances [n_atoms, max_num_neighbors]. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :return: Tensor of long-range features [n_atoms, d_pet]. + """ + if self.training: + # Currently, the long-range implementation show instabilities + # during training if P3MCalculator is used instead of the + # EwaldCalculator. We will use the EwaldCalculator for training. 
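+            # (a warning about this force-switch is already emitted in __init__
+            # when the hypers request `use_ewald=False`)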
+ self.long_range_featurizer.use_ewald = True + flattened_lengths = edge_distances[padding_mask] + short_range_features = ( + torch.stack(node_features_list).sum(dim=0) + * (1 / len(node_features_list)) ** 0.5 + ) + long_range_features = self.long_range_featurizer( + systems, short_range_features, flattened_lengths + ) + return long_range_features + + def _get_output_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Concatenate node and edge features from all layers into intermediate + feature representations. Edge features are summed with cutoff weighting. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping "features" to a TensorMap of intermediate + representations, either per-atom or summed over atoms. + """ + features_dict: Dict[str, TensorMap] = {} + node_features = torch.cat(node_features_list, dim=1) + edge_features = torch.cat(edge_features_list, dim=2) + edge_features = (edge_features * cutoff_factors[:, :, None]).sum(dim=1) + features = torch.cat([node_features, edge_features], dim=1) + + feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=features, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + features.shape[-1], device=features.device + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + feature_tmap = mts.slice( + feature_tmap, + axis="samples", + selection=selected_atoms, + ) + if requested_outputs["features"].per_atom: + features_dict["features"] = feature_tmap + else: + features_dict["features"] = sum_over_atoms(feature_tmap) + return features_dict + + def _calculate_last_layer_features( + self, + node_features_list: List[torch.Tensor], + edge_features_list: List[torch.Tensor], + ) -> Tuple[Dict[str, List[torch.Tensor]], Dict[str, List[torch.Tensor]]]: + """ + Apply output-specific heads to node and edge features from each GNN layer. + Returns dictionaries mapping output names to lists of head-transformed features. + + :param node_features_list: List of node feature tensors from each GNN layer. + :param edge_features_list: List of edge feature tensors from each GNN layer. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of node last layer features + - Dictionary mapping output names to lists of edge last layer features + """ + node_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + + # Calculating node last layer features + for output_name, node_heads in self.node_heads.items(): + if output_name not in node_last_layer_features_dict: + node_last_layer_features_dict[output_name] = [] + for i, node_head in enumerate(node_heads): + node_last_layer_features_dict[output_name].append( + node_head(node_features_list[i]) + ) + + # Calculating edge last layer features + for output_name, edge_heads in self.edge_heads.items(): + if output_name not in edge_last_layer_features_dict: + edge_last_layer_features_dict[output_name] = [] + for i, edge_head in enumerate(edge_heads): + edge_last_layer_features_dict[output_name].append( + edge_head(edge_features_list[i]) + ) + + return node_last_layer_features_dict, edge_last_layer_features_dict + + def _get_output_last_layer_features( + self, + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + cutoff_factors: torch.Tensor, + selected_atoms: Optional[Labels], + sample_labels: Labels, + requested_outputs: Dict[str, ModelOutput], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge last layer features for requested last layer + features output. Edge features are summed with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param requested_outputs: Dictionary of requested outputs. + :return: Dictionary mapping requested last layer features output names + to TensorMaps of last layer features, either per-atom or summed over atoms. 
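+
+        For example, for a target named ``energy``, the combined last layer
+        features are returned under the auxiliary output name
+        ``mtt::aux::energy_last_layer_features`` (see
+        ``get_last_layer_features_name``).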
+ """ + last_layer_features_dict: Dict[str, List[torch.Tensor]] = {} + last_layer_features_outputs: Dict[str, TensorMap] = {} + for output_name in node_last_layer_features_dict.keys(): + if not should_compute_last_layer_features(output_name, requested_outputs): + continue + if output_name not in last_layer_features_dict: + last_layer_features_dict[output_name] = [] + for i in range(len(node_last_layer_features_dict[output_name])): + node_last_layer_features = node_last_layer_features_dict[output_name][i] + edge_last_layer_features = edge_last_layer_features_dict[output_name][i] + edge_last_layer_features = ( + edge_last_layer_features * cutoff_factors[:, :, None] + ).sum(dim=1) + last_layer_features_dict[output_name].append(node_last_layer_features) + last_layer_features_dict[output_name].append(edge_last_layer_features) + + for output_name in requested_outputs: + if not ( + output_name.startswith("mtt::aux::") + and output_name.endswith("_last_layer_features") + ): + continue + base_name = output_name.replace("mtt::aux::", "").replace( + "_last_layer_features", "" + ) + # the corresponding output could be base_name or mtt::base_name + if f"mtt::{base_name}" in last_layer_features_dict: + base_name = f"mtt::{base_name}" + last_layer_features_values = torch.cat( + last_layer_features_dict[base_name], dim=1 + ) + last_layer_feature_tmap = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=last_layer_features_values, + samples=sample_labels, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange( + last_layer_features_values.shape[-1], + device=last_layer_features_values.device, + ).reshape(-1, 1), + assume_unique=True, + ), + ) + ], + ) + if selected_atoms is not None: + last_layer_feature_tmap = mts.slice( + last_layer_feature_tmap, + axis="samples", + selection=selected_atoms, + ) + last_layer_features_options = requested_outputs[output_name] + if last_layer_features_options.per_atom: + last_layer_features_outputs[output_name] = last_layer_feature_tmap + else: + last_layer_features_outputs[output_name] = sum_over_atoms( + last_layer_feature_tmap + ) + return last_layer_features_outputs + + def _calculate_atomic_predictions( + self, + systems: List[System], + node_last_layer_features_dict: Dict[str, List[torch.Tensor]], + edge_last_layer_features_dict: Dict[str, List[torch.Tensor]], + padding_mask: torch.Tensor, + cutoff_factors: torch.Tensor, + outputs: Dict[str, ModelOutput], + ) -> Tuple[ + Dict[str, List[List[torch.Tensor]]], Dict[str, List[List[torch.Tensor]]] + ]: + """ + Apply final linear layers to last layer features to produce + per-atom predictions. Handles multiple blocks per output and sums + edge contributions with cutoff weighting. + + :param node_last_layer_features_dict: Dictionary mapping output names to + lists of node last layer features. + :param edge_last_layer_features_dict: Dictionary mapping output names to + lists of edge last layer features. + :param padding_mask: Boolean mask indicating real vs padded neighbors + [n_atoms, max_num_neighbors]. + :param cutoff_factors: Tensor of cutoff factors for edge distances + [n_atoms, max_num_neighbors]. + :param outputs: Dictionary of requested outputs. 
+ :return: Tuple of two dictionaries: + - Dictionary mapping output names to lists of lists of node atomic + prediction tensors (one list per GNN layer, one tensor per block) + - Dictionary mapping output names to lists of lists of edge atomic + prediction tensors (one list per GNN layer, one tensor per block) + """ + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]] = {} + + # Computing node atomic predictions. Since we have last layer features + # for each GNN layer, and each last layer can have multiple blocks, + # we apply each last layer block to each of the last layer features. + + batch_species_indices = self.species_to_species_index[ + torch.cat([system.types for system in systems], dim=0) + ] + + for output_name, node_last_layers in self.node_last_layers.items(): + if output_name in outputs: + node_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, node_last_layer in enumerate(node_last_layers): + node_last_layer_features = node_last_layer_features_dict[ + output_name + ][i] + node_atomic_predictions_by_block: List[torch.Tensor] = [] + for node_last_layer_by_block in node_last_layer.values(): + node_atomic_predictions_by_block.append( + node_last_layer_by_block( + batch_species_indices, + node_last_layer_features, + ) + ) + node_atomic_predictions_dict[output_name].append( + node_atomic_predictions_by_block + ) + + # Computing edge atomic predictions. Following the same logic as above, + # we (1) iterate over the last layer features and last layer blocks, and (2) + # sum the edge features with cutoff factors to get their per-node contribution. + + for output_name, edge_last_layers in self.edge_last_layers.items(): + if output_name in outputs: + edge_atomic_predictions_dict[output_name] = torch.jit.annotate( + List[List[torch.Tensor]], [] + ) + for i, edge_last_layer in enumerate(edge_last_layers): + edge_last_layer_features = edge_last_layer_features_dict[ + output_name + ][i] + edge_atomic_predictions_by_block: List[torch.Tensor] = [] + for edge_last_layer_by_block in edge_last_layer.values(): + edge_atomic_predictions = edge_last_layer_by_block( + batch_species_indices, edge_last_layer_features + ) + expanded_padding_mask = padding_mask[..., None].repeat( + 1, 1, edge_atomic_predictions.shape[2] + ) + edge_atomic_predictions = torch.where( + ~expanded_padding_mask, 0.0, edge_atomic_predictions + ) + edge_atomic_predictions_by_block.append( + (edge_atomic_predictions * cutoff_factors[:, :, None]).sum( + dim=1 + ) + ) + edge_atomic_predictions_dict[output_name].append( + edge_atomic_predictions_by_block + ) + + return node_atomic_predictions_dict, edge_atomic_predictions_dict + + def _get_output_atomic_predictions( + self, + systems: List[System], + node_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_atomic_predictions_dict: Dict[str, List[List[torch.Tensor]]], + edge_vectors: torch.Tensor, + system_indices: torch.Tensor, + sample_labels: Labels, + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + ) -> Dict[str, TensorMap]: + """ + Combine node and edge atomic predictions into final TensorMaps. + Handles rank-2 Cartesian tensors by symmetrizing them. + Returns per-atom or per-structure predictions based on output configuration. + + :param systems: List of `metatomic.torch.System` objects to process. 
+ :param node_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of node atomic prediction tensors (one list per GNN layer, + one tensor per block). + :param edge_atomic_predictions_dict: Dictionary mapping output names to + lists of lists of edge atomic prediction tensors (one list per GNN layer, + one tensor per block). + :param edge_vectors: Tensor of edge vectors [n_atoms, max_num_neighbors, 3]. + :param system_indices: Tensor mapping each atom to its system index + [n_atoms]. + :param sample_labels: Labels for all atoms in the batch [n_atoms, 2]. + :param outputs: Dictionary of requested outputs. + :param selected_atoms: Optional Labels specifying a subset of atoms to include. + :return: Dictionary mapping requested output names to TensorMaps of + predictions, either per-atom or summed over atoms. + """ + atomic_predictions_tmap_dict: Dict[str, TensorMap] = {} + for output_name in self.target_names: + if output_name in outputs: + atomic_predictions_by_block = { + key: torch.zeros( + 1, dtype=edge_vectors.dtype, device=edge_vectors.device + ) + for key in self.output_shapes[output_name].keys() + } + + node_atomic_predictions_by_block = node_atomic_predictions_dict[ + output_name + ] + edge_atomic_predictions_by_block = edge_atomic_predictions_dict[ + output_name + ] + for i in range(len(node_atomic_predictions_by_block)): + node_atomic_prediction_block = node_atomic_predictions_by_block[i] + edge_atomic_prediction_block = edge_atomic_predictions_by_block[i] + for j, key in enumerate(atomic_predictions_by_block): + node_atomic_predictions = node_atomic_prediction_block[j] + edge_atomic_predictions = edge_atomic_prediction_block[j] + atomic_predictions_by_block[key] = atomic_predictions_by_block[ + key + ] + (node_atomic_predictions + edge_atomic_predictions) + + if output_name == "non_conservative_stress": # TODO: variants + block_key = list(atomic_predictions_by_block.keys())[0] + output_shapes_values = list( + self.output_shapes[output_name].values() + ) + num_properties = output_shapes_values[0][-1] + symmetrized = process_non_conservative_stress( + atomic_predictions_by_block[block_key], + systems, + system_indices, + num_properties, + ) + atomic_predictions_by_block[block_key] = symmetrized + + blocks = [ + TensorBlock( + values=atomic_predictions_by_block[key].reshape([-1] + shape), + samples=sample_labels, + components=components, + properties=properties, + ) + for key, shape, components, properties in zip( + self.output_shapes[output_name].keys(), + self.output_shapes[output_name].values(), + self.component_labels[output_name], + self.property_labels[output_name], + strict=True, + ) + ] + atomic_predictions_tmap_dict[output_name] = TensorMap( + keys=self.key_labels[output_name], + blocks=blocks, + ) + # If selected atoms request is provided, we slice the atomic predictions + # tensor maps to get the predictions for the selected atoms only. + + if selected_atoms is not None: + for output_name, tmap in atomic_predictions_tmap_dict.items(): + atomic_predictions_tmap_dict[output_name] = mts.slice( + tmap, axis="samples", selection=selected_atoms + ) + + # If per-atom predictions are requested, we return the atomic predictions + # tensor maps. Otherwise, we sum the atomic predictions over the atoms + # to get the final per-structure predictions for each requested output. 
+ + for output_name, atomic_property in atomic_predictions_tmap_dict.items(): + if outputs[output_name].per_atom: + atomic_predictions_tmap_dict[output_name] = atomic_property + else: + atomic_predictions_tmap_dict[output_name] = sum_over_atoms( + atomic_property + ) + + return atomic_predictions_tmap_dict + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + context: Literal["restart", "finetune", "export"], + ) -> "PET": + if context == "restart": + logging.info(f"Using latest model from epoch {checkpoint['epoch']}") + model_state_dict = checkpoint["model_state_dict"] + elif context in {"finetune", "export"}: + logging.info(f"Using best model from epoch {checkpoint['best_epoch']}") + model_state_dict = checkpoint["best_model_state_dict"] + else: + raise ValueError("Unknown context tag for checkpoint loading!") + + # Create the model + model_data = checkpoint["model_data"] + model = cls( + hypers=model_data["model_hypers"], + dataset_info=model_data["dataset_info"], + ) + + finetune_config = model_state_dict.pop("finetune_config", {}) + if finetune_config: + # Apply the finetuning strategy + model = apply_finetuning_strategy(model, finetune_config) + state_dict_iter = iter(model_state_dict.values()) + next(state_dict_iter) # skip the species_to_species_index + dtype = next(state_dict_iter).dtype + model.to(dtype).load_state_dict(model_state_dict) + model.additive_models[0].sync_tensor_maps() + model.scaler.sync_tensor_maps() + + # Loading the metadata from the checkpoint + model.metadata = merge_metadata(model.metadata, checkpoint.get("metadata")) + + return model + + def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: + dtype = next(self.parameters()).dtype + if dtype not in self.__supported_dtypes__: + raise ValueError(f"unsupported dtype {dtype} for PET") + + # Make sure the model is all in the same dtype + # For example, after training, the additive models could still be in + # float64 + self.to(dtype) + + # Additionally, the composition model contains some `TensorMap`s that cannot + # be registered correctly with Pytorch. This function moves them: + self.additive_models[0].weights_to(torch.device("cpu"), torch.float64) + + interaction_ranges = [self.num_gnn_layers * self.cutoff] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + + capabilities = ModelCapabilities( + outputs=self.outputs, + atomic_types=self.atomic_types, + interaction_range=interaction_range, + length_unit=self.dataset_info.length_unit, + supported_devices=self.__supported_devices__, + dtype=dtype_to_str(dtype), + ) + + metadata = merge_metadata(self.metadata, metadata) + + return AtomisticModel(self.eval(), metadata, capabilities) + + def _add_output(self, target_name: str, target_info: TargetInfo) -> None: + """ + Register a new output target by creating corresponding heads and last layers. + Sets up node/edge heads and linear layers for all readout layers. + + :param target_name: Name of the target to add. + :param target_info: TargetInfo object containing details about the target. + """ + # one output shape for each tensor block, grouped by target (i.e. 
tensormap) + self.output_shapes[target_name] = {} + for key, block in target_info.layout.items(): + dict_key = target_name + for n, k in zip(key.names, key.values, strict=True): + dict_key += f"_{n}_{int(k)}" + self.output_shapes[target_name][dict_key] = [ + len(comp.values) for comp in block.components + ] + [len(block.properties.values)] + + self.outputs[target_name] = ModelOutput( + quantity=target_info.quantity, + unit=target_info.unit, + per_atom=True, + description=target_info.description, + ) + + self.node_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_node, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_heads[target_name] = torch.nn.ModuleList( + [ + torch.nn.Sequential( + torch.nn.Linear(self.d_pet, self.d_head), + torch.nn.SiLU(), + torch.nn.Linear(self.d_head, self.d_head), + torch.nn.SiLU(), + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.node_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: ReadoutLayer( + self.d_head, + prod(shape), + len(self.atomic_types), + bias=True, + num_experts=self.num_experts, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + self.edge_last_layers[target_name] = torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + key: ReadoutLayer( + self.d_head, + prod(shape), + len(self.atomic_types), + bias=True, + num_experts=self.num_experts, + ) + for key, shape in self.output_shapes[target_name].items() + } + ) + for _ in range(self.num_readout_layers) + ] + ) + + # Register last-layer parameters, in the same order as they are returned as + # last-layer features in the model + self.last_layer_parameter_names[target_name] = [] + for layer_index in range(self.num_readout_layers): + for key in self.output_shapes[target_name].keys(): + self.last_layer_parameter_names[target_name].append( + f"node_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + self.last_layer_parameter_names[target_name].append( + f"edge_last_layers.{target_name}.{layer_index}.{key}.weight" + ) + + ll_features_name = get_last_layer_features_name(target_name) + self.outputs[ll_features_name] = ModelOutput( + per_atom=True, description=f"last layer features for {target_name}" + ) + self.key_labels[target_name] = target_info.layout.keys + self.component_labels[target_name] = [ + block.components for block in target_info.layout.blocks() + ] + self.property_labels[target_name] = [ + block.properties for block in target_info.layout.blocks() + ] + + def _move_labels_to_device(self, device: torch.device) -> None: + self.single_label = self.single_label.to(device) + self.key_labels = { + output_name: label.to(device) + for output_name, label in self.key_labels.items() + } + self.component_labels = { + output_name: [ + [labels.to(device) for labels in components_block] + for components_block in components_tmap + ] + for output_name, components_tmap in self.component_labels.items() + } + self.property_labels = { + output_name: [labels.to(device) for labels in properties_tmap] + for output_name, properties_tmap in self.property_labels.items() + } + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["model_ckpt_version"] == v: + update = getattr(checkpoints, f"model_update_v{v}_v{v + 1}") + update(checkpoint) + 
checkpoint["model_ckpt_version"] = v + 1 + + if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using model " + f"version {checkpoint['model_ckpt_version']}, while the current model " + f"version is {cls.__checkpoint_version__}." + ) + + return checkpoint + + def get_checkpoint(self) -> Dict: + model_state_dict = self.state_dict() + model_state_dict["finetune_config"] = self.finetune_config + checkpoint = { + "architecture_name": "pet_9", + "model_ckpt_version": self.__checkpoint_version__, + "metadata": self.metadata, + "model_data": { + "model_hypers": self.hypers, + "dataset_info": self.dataset_info, + }, + "epoch": None, + "best_epoch": None, + "model_state_dict": model_state_dict, + "best_model_state_dict": self.state_dict(), + } + return checkpoint + + +def process_non_conservative_stress( + tensor: torch.Tensor, + systems: List[System], + system_indices: torch.Tensor, + num_properties: int, +) -> torch.Tensor: + """ + Symmetrizes and normalizes by the volume rank-2 Cartesian tensors that are meant + to predict the non-conservative stress. + + :param tensor: Tensor of shape [n_atoms, 9 * num_properties]. + :param systems: List of `metatomic.torch.System` objects to process. + :param system_indices: Tensor mapping each atom to its system index [n_atoms]. + :param num_properties: Number of properties in the tensor (e.g., 6 for stress). + :return: Symmetrized tensor of shape [n_atoms, 3, 3, num_properties], divided by the + cell volume. + """ + # Reshape to 3x3 matrix per atom + tensor_as_three_by_three = tensor.reshape(-1, 3, 3, num_properties) + + # Normalize by cell volume + volumes = torch.stack([torch.abs(torch.det(system.cell)) for system in systems]) + # Zero volume can happen due to metatomic's convention of zero cell + # vectors for non-periodic directions. The actual volume is +inf + volumes[volumes == 0.0] = torch.inf + volumes_by_atom = volumes[system_indices].unsqueeze(1).unsqueeze(2).unsqueeze(3) + tensor_as_three_by_three = tensor_as_three_by_three / volumes_by_atom + + # Symmetrize + tensor_as_three_by_three = ( + tensor_as_three_by_three + tensor_as_three_by_three.transpose(1, 2) + ) / 2.0 + + return tensor_as_three_by_three + + +def get_last_layer_features_name(target_name: str) -> str: + """ + Get the auxiliary output name for last layer features of a target. + + :param target_name: Name of the target. + :return: Name of the corresponding last layer features output. + """ + base_name = target_name.replace("mtt::", "") + return f"mtt::aux::{base_name}_last_layer_features" + + +def should_compute_last_layer_features( + output_name: str, requested_outputs: Dict[str, ModelOutput] +) -> bool: + """ + Check if last layer features should be computed for an output. + + :param output_name: Name of the output to check. + :param requested_outputs: Dictionary of requested outputs. + :return: True if last layer features should be computed, False otherwise. 
+ """ + if output_name in requested_outputs: + return True + ll_features_name = get_last_layer_features_name( + output_name.replace("mtt::aux::", "") + ) + return ll_features_name in requested_outputs diff --git a/src/metatrain/pet_9/modules/adaptive_cutoff.py b/src/metatrain/pet_9/modules/adaptive_cutoff.py new file mode 100644 index 0000000000..5f6206bdb1 --- /dev/null +++ b/src/metatrain/pet_9/modules/adaptive_cutoff.py @@ -0,0 +1,166 @@ +from typing import Optional + +import torch + +from .utilities import cutoff_func_bump as cutoff_func + + +# minimum value for the probe cutoff. this avoids getting too close +# to the central atom. in practice it could be also set to a larger value +DEFAULT_MIN_PROBE_CUTOFF = 0.5 +# recommended smooth cutoff width for effective neighbor number calculation +# smaller values lead to a more "step-like" behavior, but can be +# numerically unstable. in practice this will be called with the +# same cutoff as the main cutoff function +DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH = 1.0 + + +def get_adaptive_cutoffs( + centers: torch.Tensor, + edge_distances: torch.Tensor, + num_neighbors_adaptive: float, + num_nodes: int, + max_cutoff: float, + min_cutoff: float = DEFAULT_MIN_PROBE_CUTOFF, + cutoff_width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, + probe_spacing: Optional[float] = None, + weight_width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the adaptive cutoff values for each center atom. + + :param centers: Indices of the center atoms. + :param edge_distances: Distances between centers and their neighbors. + :param num_neighbors_adaptive: Target number of neighbors per atom. + :param num_nodes: Total number of center atoms. + :param max_cutoff: Maximum cutoff distance to consider. + :param min_cutoff: Minimum cutoff distance to consider. + :param cutoff_width: Width of the smooth cutoff taper region. + :param probe_spacing: Spacing between probe cutoffs. If None, it will be + automatically determined from the cutoff width. + :param weight_width: Width of the cutoff selection weight function. If None, it + will be automatically determined from the empirical neighbor counts. + :return: Adapted cutoff distances for each center atom. + """ + + # heuristic for the grid spacing of probe cutoffs, based on a + # the smoothness of the cutoff function + if probe_spacing is None: + probe_spacing = cutoff_width / 4.0 + probe_cutoffs = torch.arange( + min_cutoff, + max_cutoff, + probe_spacing, + device=edge_distances.device, + dtype=edge_distances.dtype, + ) + with torch.profiler.record_function("PET::get_effective_num_neighbors"): + effective_num_neighbors = get_effective_num_neighbors( + edge_distances, + probe_cutoffs, + centers, + num_nodes, + width=cutoff_width, + ) + + with torch.profiler.record_function("PET::get_cutoff_weights"): + cutoffs_weights = get_gaussian_cutoff_weights( + effective_num_neighbors, + num_neighbors_adaptive, + width=weight_width, + ) + with torch.profiler.record_function("PET::calculate_adapted_cutoffs"): + adapted_atomic_cutoffs = probe_cutoffs @ cutoffs_weights.T + return adapted_atomic_cutoffs + + +def get_effective_num_neighbors( + edge_distances: torch.Tensor, + probe_cutoffs: torch.Tensor, + centers: torch.Tensor, + num_nodes: int, + width: float = DEFAULT_EFFECTIVE_NUM_NEIGHBORS_WIDTH, +) -> torch.Tensor: + """ + Computes the effective number of neighbors for each probe cutoff. + + :param edge_distances: Distances between centers and their neighbors. + :param probe_cutoffs: Probe cutoff distances. 
+ :param centers: Indices of the center atoms. + :param num_nodes: Total number of center atoms. + :param width: Width of the cutoff function. + :return: Effective number of neighbors for each center atom and probe cutoff. + """ + + weights = cutoff_func( + edge_distances.unsqueeze(0), probe_cutoffs.unsqueeze(1), width + ) + + probe_num_neighbors = torch.zeros( + (len(probe_cutoffs), num_nodes), + dtype=edge_distances.dtype, + device=edge_distances.device, + ) + # accumulate the weights for all probe cutoffs and center atoms at once + probe_num_neighbors.index_add_(1, centers, weights) + probe_num_neighbors = probe_num_neighbors.T + + return probe_num_neighbors + + +def get_gaussian_cutoff_weights( + effective_num_neighbors: torch.Tensor, + num_neighbors_adaptive: float, + width: Optional[float] = None, +) -> torch.Tensor: + """ + Computes the weights for each probe cutoff based on + the effective number of neighbors using Gaussian weights + centered at the expected number of neighbors. + + :param effective_num_neighbors: Effective number of neighbors for each center atom + and probe cutoff. + :param num_neighbors_adaptive: Target maximum number of neighbors per atom. + :param width: Width of the Gaussian cutoff selection function. + :return: Weights for each probe cutoff. + """ + diff = effective_num_neighbors - num_neighbors_adaptive + + # adds a "baseline" corresponding to uniformly-distributed atoms + # this has multiple "good" effects: it pushes the cutoff "out" when + # there are few neighbors, and "in" when there are many, and it + # stabilizes the weights with respect to variations in the neighbor + # distribution when there are empty ranges leading to "flat" + # neighbor count distribution + x = torch.linspace( + 0, + 1, + effective_num_neighbors.shape[1], + device=effective_num_neighbors.device, + dtype=effective_num_neighbors.dtype, + ) + baseline = num_neighbors_adaptive * x**3 + + diff = diff + baseline.unsqueeze(0) + if width is None: + # adaptive width from neighbor-count slope along probe axis (last dim) + eps = 1e-12 + if diff.shape[-1] == 1: + # Can't compute gradient from single point; use scaled diff as proxy + width_t = diff.abs() * 0.5 + eps + else: + # Compute numerical gradient: centered differences for interior, + # one-sided differences at boundaries + (width_t,) = torch.gradient(diff, dim=-1) + width_t = width_t.abs().clamp_min(eps) + else: + width_t = torch.ones_like(diff) * width + + logw = -0.5 * (diff / width_t) ** 2 + weights = torch.exp(logw - logw.max()) + + # row-wise normalization of the weights + weights_sum = weights.sum(dim=1, keepdim=True) + weights = weights / weights_sum + + return weights diff --git a/src/metatrain/pet_9/modules/finetuning.py b/src/metatrain/pet_9/modules/finetuning.py new file mode 100644 index 0000000000..27f2086f19 --- /dev/null +++ b/src/metatrain/pet_9/modules/finetuning.py @@ -0,0 +1,264 @@ +# mypy: disable-error-code=misc +# We ignore misc errors in this file because TypedDict +# with default values is not allowed by mypy. 
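For orientation, a minimal sketch of a finetuning configuration matching the TypedDicts defined below (the checkpoint path and the numeric values are placeholders, not defaults of the architecture):

    finetune_options = {
        "method": "lora",
        "read_from": "path/to/pretrained.ckpt",  # hypothetical path
        "config": {"rank": 4, "alpha": 8.0},
        "inherit_heads": {},
    }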
+from typing import Any, Optional, Tuple + +import torch +import torch.nn as nn +from typing_extensions import Literal, NotRequired, TypedDict + + +class LoRaFinetuneConfig(TypedDict): + """Configuration for LoRA finetuning strategy.""" + + rank: int + """Rank of the LoRA matrices.""" + alpha: float + """Scaling factor for the LoRA matrices.""" + target_modules: NotRequired[list[str]] + + +class HeadsFinetuneConfig(TypedDict): + """Configuration for heads finetuning strategy.""" + + head_modules: list[str] + """List of module name prefixes for the prediction heads to finetune.""" + last_layer_modules: list[str] + """List of module name prefixes for the last layers to finetune.""" + + +class NoFinetuneHypers(TypedDict): + """Hypers that indicate that no finetuning is to be applied.""" + + read_from: None = None + """No finetuning is indicated by setting this argument to None. + + The rest of finetuning hyperparameters are then ignored. + """ + method: NotRequired[Any] + config: NotRequired[Any] + inherit_heads: NotRequired[Any] + + +class FullFinetuneHypers(TypedDict): + """Hyperparameters to use full finetuning of PET models. + + This means all model parameters are trainable. + """ + + method: Literal["full"] = "full" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: NotRequired[Any] + """No configuration needed for full finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class LoRaFinetuneHypers(TypedDict): + """Hyperparameters for LoRA finetuning of PET models. + + Injects LoRA layers and finetunes only them. + """ + + method: Literal["lora"] = "lora" + """Finetuning method to use""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: LoRaFinetuneConfig + """Configuration for LoRA finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +class HeadsFinetuneHypers(TypedDict): + """Hyperparameters for heads finetuning of PET models. + + Freezes all model parameters except for the prediction heads + and last layers. + """ + + method: Literal["heads"] = "heads" + """Finetuning method to use.""" + read_from: str + """Path to the pretrained model checkpoint.""" + config: HeadsFinetuneConfig + """Configuration for heads finetuning.""" + inherit_heads: dict[str, str] = {} + """Mapping from new trainable targets (keys) to the existing targets + in the model (values). + This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization.""" + + +FinetuneHypers = FullFinetuneHypers | LoRaFinetuneHypers | HeadsFinetuneHypers + + +def apply_finetuning_strategy(model: nn.Module, strategy: FinetuneHypers) -> nn.Module: + """ + Apply the specified finetuning strategy to the model. + This function modifies the model in place based on the provided strategy. + + :param model: The model to be finetuned. + :param strategy: A dictionary specifying the finetuning strategy. + The strategy method can be one of the following: + - lora: Inject LoRA layers into the model, or reapply training if already + present. 
+ - heads: Freeze all parameters except for the heads and last layers. + - full: All parameters are trainable. + Additionally, the strategy can include an "inherit_heads" key, + which is a dictionary mapping the new trainable targets to the existing + targets in the model. This allows for copying weights from the corresponding + source heads to the destination heads instead of random initialization. + :return: The modified model with the finetuning strategy applied. + """ + + for param in model.parameters(): + param.requires_grad = True + + if strategy["method"] == "full": + # Full finetuning, all parameters are trainable + pass + + elif strategy["method"] == "lora": + lora_config = strategy["config"] + lora_already_applied = any(isinstance(m, LoRALinear) for m in model.modules()) + if not lora_already_applied: + model_device = next(model.parameters()).device + model_dtype = next(model.parameters()).dtype + model = inject_lora_layers( + model, + target_modules=tuple( + lora_config.get("target_modules", ["input_linear", "output_linear"]) + ), + rank=lora_config.get("rank", 4), + alpha=lora_config.get("alpha", 8), + device=model_device, + dtype=model_dtype, + ) + + # Freeze all except LoRA + for name, param in model.named_parameters(): + if "lora_" not in name: + param.requires_grad = False + + elif strategy["method"] == "heads": + heads_config = strategy.get( + "config", + { + "head_modules": ["node_heads", "edge_heads"], + "last_layer_modules": ["node_last_layers", "edge_last_layers"], + }, + ) + + head_keywords = heads_config.get("head_modules", []) + last_layer_keywords = heads_config.get("last_layer_modules", []) + + for name, param in model.named_parameters(): + if any(name.startswith(kw) for kw in head_keywords + last_layer_keywords): + param.requires_grad = True + else: + param.requires_grad = False + + else: + raise ValueError( + f"Unknown finetuning strategy: {strategy['method']}. Available methods " + "are: 'full', 'lora', 'heads'." + ) + + model.finetune_config = strategy + + inherit_heads_config = strategy["inherit_heads"] + if inherit_heads_config: + for dest_target_name, source_target_name in inherit_heads_config.items(): + model_parameters = dict(model.named_parameters()) + if not any(f".{source_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the source target name '{source_target_name}' was not found in " + "the model. Please specify the correct source target name." + ) + if not any(f".{dest_target_name}." in name for name in model_parameters): + raise ValueError( + f"Weight inheritance was selected in finetuning strategy, but " + f"the destination target name '{dest_target_name}' was not found " + "in the model. Please specify the correct destination target name." + ) + for name, param in model_parameters.items(): + if f".{source_target_name}." in name: + corresponding_dest_name = name.replace( + source_target_name, dest_target_name + ) + if corresponding_dest_name in model_parameters: + model_parameters[corresponding_dest_name].data.copy_(param.data) + else: + raise ValueError( + f"Destination head '{dest_target_name}' not found in model." + ) + return model + + +def inject_lora_layers( + model: nn.Module, + target_modules: Tuple[str, ...] = ("input_linear", "output_linear"), + rank: int = 4, + alpha: float = 1.0, + device: Optional[str] = None, + dtype: Optional[torch.dtype] = None, +) -> nn.Module: + """ + Inject LoRA layers into the model. 
+ This function replaces the specified linear layers in the model with + LoRALinear layers. + + :param model: The model to modify. + :param target_modules: A tuple of strings specifying the names of the attributes of + the modules to be replaced with LoRA layers. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + :param device: The device to which the LoRA layers should be moved. If None, the + LoRA layers will be on the same device as the original model. + :param dtype: The data type to which the LoRA layers should be converted. If + None, the LoRA layers will have the same dtype as the original model. + :return: The modified model with LoRA layers injected. + """ + for _, module in model.named_modules(): + for attr in target_modules: + if hasattr(module, attr): + linear = getattr(module, attr) + if isinstance(linear, nn.Linear): + lora_linear = LoRALinear(linear, rank=rank, alpha=alpha) + lora_linear = lora_linear.to(dtype=dtype, device=device) + setattr(module, attr, lora_linear) + return model + + +class LoRALinear(nn.Module): + """ + LoRA Linear layer. + This is a wrapper around nn.Linear that adds LoRA functionality. + LoRA is a technique for low-rank adaptation of large language models. + It allows for efficient fine-tuning of large models by injecting low-rank + matrices into the model's weights. + + :param linear_layer: The original linear layer to be wrapped. + :param rank: The rank of the LoRA matrices. + :param alpha: The scaling factor for the LoRA matrices. + """ + + def __init__(self, linear_layer: nn.Module, rank: int = 4, alpha: float = 1.0): + super().__init__() + self.linear = linear_layer + self.lora_A = nn.Linear(linear_layer.in_features, rank, bias=False) + self.lora_B = nn.Linear(rank, linear_layer.out_features, bias=False) + self.scaling = alpha / rank + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + self.scaling * self.lora_B(self.lora_A(x)) diff --git a/src/metatrain/pet_9/modules/nef.py b/src/metatrain/pet_9/modules/nef.py new file mode 100644 index 0000000000..20ed6dc8fc --- /dev/null +++ b/src/metatrain/pet_9/modules/nef.py @@ -0,0 +1,248 @@ +""" +Module with functions to manipulate NEF (Node Edge Feature) arrays. + +The NEF representation is what the internals of PET use. +In the NEF representation, the first dimension is the center node +(i.e. the "i" node in an "i -> j" edge), and the second dimension +is the edges for that node. Not all center nodes have the same number +of edges, so padding is used to ensure that all nodes have the same +number of edges. + +Most of the functions have the purpose of converting between +edge arrays with shape (n_edges, ...) and NEF arrays with shape +(n_nodes, n_edges_per_node, ...). +""" + +from typing import List, Optional, Tuple + +import torch + + +def get_nef_indices( + centers: torch.Tensor, n_nodes: int, n_edges_per_node: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes tensors of indices useful to convert between edge + and NEF layouts; the usage and function of `nef_indices` and + `nef_to_edges_neighbor` is clear in the ``edge_array_to_nef`` + and ``nef_array_to_edges`` functions below. + + :param centers: A 1D tensor of shape (n_edges,) containing the + indices of the center nodes for each edge, with the center nodes + being the "i" node in an "i -> j" edge. + :param n_nodes: The number of nodes in the graph. + :param n_edges_per_node: The maximum number of edges per node. 
+ + :return: A tuple with three tensors (nef_indices, nef_to_edges_neighbor, nef_mask). + In particular: + nef_array = edge_array[nef_indices] + edge_array = nef_array[centers, nef_to_edges_neighbor] + The third output, nef_mask, is a mask that can be used to + filter out the padding values in the NEF array, as different + nodes will have, in general, different number of edges. + """ + + bincount = torch.bincount(centers, minlength=n_nodes) + + arange = torch.arange(n_edges_per_node, device=centers.device) + arange_expanded = arange.view(1, -1).expand(n_nodes, -1) + nef_mask = arange_expanded < bincount.view(-1, 1) + + argsort = torch.argsort(centers, stable=True) + + nef_indices = torch.zeros( + (n_nodes, n_edges_per_node), dtype=torch.long, device=centers.device + ) + nef_indices[nef_mask] = argsort + + nef_to_edges_neighbor = torch.empty_like(centers, dtype=torch.long) + nef_to_edges_neighbor[argsort] = arange_expanded[nef_mask] + + return nef_indices, nef_to_edges_neighbor, nef_mask + + +def get_corresponding_edges(array: torch.Tensor) -> torch.Tensor: + """ + Computes the corresponding edge (i.e., the edge that goes in the + opposite direction) for each edge in the array; this is useful + in the message-passing operation. + + :param array: A 2D tensor of shape (n_edges, 5). For each i -> j + edge, the first column contains the index of the center node i, + the second column contains the index of the neighbor node j, + and the last three columns contain the cell shifts along x, y, and z + directions, respectively. + + :return: A 1D tensor of shape (n_edges,) containing, for each edge, + the index of the corresponding edge (i.e., the edge that goes + in the opposite direction). If the input array is empty, an + empty tensor is returned. + """ + + if array.numel() == 0: + return torch.empty((0,), dtype=array.dtype, device=array.device) + + array = array.to(torch.int64) # avoid overflow + + centers = array[:, 0] + neighbors = array[:, 1] + cell_shifts_x = array[:, 2] + cell_shifts_y = array[:, 3] + cell_shifts_z = array[:, 4] + + # will be useful later + negative_cell_shifts_x = -cell_shifts_x + negative_cell_shifts_y = -cell_shifts_y + negative_cell_shifts_z = -cell_shifts_z + + # create a unique identifier for each edge + # first, we shift the cell_shifts so that the minimum value is 0 + min_cell_shift_x = cell_shifts_x.min() + cell_shifts_x = cell_shifts_x - min_cell_shift_x + negative_cell_shifts_x = negative_cell_shifts_x - min_cell_shift_x + + min_cell_shift_y = cell_shifts_y.min() + cell_shifts_y = cell_shifts_y - min_cell_shift_y + negative_cell_shifts_y = negative_cell_shifts_y - min_cell_shift_y + + min_cell_shift_z = cell_shifts_z.min() + cell_shifts_z = cell_shifts_z - min_cell_shift_z + negative_cell_shifts_z = negative_cell_shifts_z - min_cell_shift_z + + max_centers_neigbors = centers.max() + 1 # same as neighbors.max() + 1 + max_shift_x = cell_shifts_x.max() + 1 + max_shift_y = cell_shifts_y.max() + 1 + max_shift_z = cell_shifts_z.max() + 1 + + size_1 = max_shift_z + size_2 = max_shift_y * size_1 + size_3 = max_shift_x * size_2 + size_4 = max_centers_neigbors * size_3 + + unique_id = ( + centers * size_4 + + neighbors * size_3 + + cell_shifts_x * size_2 + + cell_shifts_y * size_1 + + cell_shifts_z + ) + + # the inverse is the same, but centers and neighbors are swapped + # and we use the negative values of the cell_shifts + unique_id_inverse = ( + neighbors * size_4 + + centers * size_3 + + negative_cell_shifts_x * size_2 + + negative_cell_shifts_y * size_1 + + 
negative_cell_shifts_z + ) + + unique_id_argsort = unique_id.argsort() + unique_id_inverse_argsort = unique_id_inverse.argsort() + + corresponding_edges = torch.empty_like(centers) + corresponding_edges[unique_id_argsort] = unique_id_inverse_argsort + + return corresponding_edges.to(array.dtype) + + +def edge_array_to_nef( + edge_array: torch.Tensor, + nef_indices: torch.Tensor, + mask: Optional[torch.Tensor] = None, + fill_value: float = 0.0, +) -> torch.Tensor: + """Converts an edge array to a NEF array. + + :param edge_array: A tensor where the first dimension is the index of + the edge, i.e. with shape (n_edges, ...). + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param mask: An optional boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. If provided, + the output NEF array will have the values in the positions + where the mask is False set to ``fill_value``. + :param fill_value: The value to use to fill the positions in the + NEF array where the mask is False. Only used if ``mask`` is + provided. + + :return: A tensor with the same information as ``edge_array``, + but in NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + If ``mask`` is provided, the values in the positions where + the mask is False are set to ``fill_value``. + """ + if mask is None: + return edge_array[nef_indices] + else: + return torch.where( + mask.reshape(mask.shape + (1,) * (len(edge_array.shape) - 1)), + edge_array[nef_indices], + fill_value, + ) + + +def nef_array_to_edges( + nef_array: torch.Tensor, centers: torch.Tensor, nef_to_edges_neighbor: torch.Tensor +) -> torch.Tensor: + """Converts a NEF array to an edge array. + + :param nef_array: A tensor where the first two dimensions are the + indices of the NEF layout, i.e. with shape (n_nodes, n_edges_per_node, ...). + :param centers: The indices of the center nodes for each edge. + :param nef_to_edges_neighbor: The indices of the edges for each + neighbor in the NEF layout, as returned by the ``get_nef_indices`` function. + + :return: A tensor with the same information as ``nef_array``, + but in edge layout, i.e. with shape (n_edges, ...). + """ + return nef_array[centers, nef_to_edges_neighbor] + + +def compute_reversed_neighbor_list( + nef_indices: torch.Tensor, + corresponding_edges: torch.Tensor, + nef_mask: torch.Tensor, +) -> torch.Tensor: + """ + Creates a reversed neighborlist, where for each + center atom `i` and its neighbor `j` in the original + neighborlist, the position of atom `i` in the list + of neighbors of atom `j` is returned. + + :param nef_indices: The indices to convert from edge to NEF layout, + as returned by the ``get_nef_indices`` function. + :param corresponding_edges: The indices of the corresponding edges, + as returned by the ``get_corresponding_edges`` function. + :param nef_mask: A boolean mask of shape (n_nodes, n_edges_per_node), + as returned by the ``get_nef_indices`` function. + :return: A tensor of the same shape as ``nef_indices``, + where each entry contains the position of the center + atom in the neighborlist of the corresponding neighbor atom. 
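A minimal sketch of how the helpers in this module compose (``centers``, ``n_nodes``, ``n_edges_per_node`` and ``edge_features`` are placeholders for whatever neighbor-list data the caller has):

    nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices(
        centers, n_nodes, n_edges_per_node
    )
    # pad per-edge features into the (n_nodes, n_edges_per_node, ...) layout ...
    nef_features = edge_array_to_nef(edge_features, nef_indices, nef_mask, 0.0)
    # ... and recover the original (n_edges, ...) layout
    recovered = nef_array_to_edges(nef_features, centers, nef_to_edges_neighbor)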
+ """ + num_atoms, max_num_neighbors = nef_indices.shape + + flat_edge_indices = nef_indices.reshape(-1) + flat_positions = torch.arange(max_num_neighbors, device=nef_indices.device).repeat( + num_atoms + ) + flat_mask = nef_mask.reshape(-1) + + if flat_edge_indices.numel() == 0: + max_edge_index = 0 + else: + max_edge_index = int(flat_edge_indices.max().item()) + 1 + size: List[int] = [max_edge_index] + + edge_index_to_position = torch.full( + size, + 0, + dtype=torch.long, + device=nef_indices.device, + ) + edge_index_to_position[flat_edge_indices[flat_mask]] = flat_positions[flat_mask] + + reverse_edge_idx = corresponding_edges[nef_indices] + reversed_neighbor_list = edge_index_to_position[reverse_edge_idx] + reversed_neighbor_list = reversed_neighbor_list.masked_fill(~nef_mask, 0) + + return reversed_neighbor_list diff --git a/src/metatrain/pet_9/modules/readout.py b/src/metatrain/pet_9/modules/readout.py new file mode 100644 index 0000000000..34aa2d2675 --- /dev/null +++ b/src/metatrain/pet_9/modules/readout.py @@ -0,0 +1,50 @@ +from typing import List +import torch + +from metatomic.torch import System + + +class ReadoutLayer(torch.nn.Module): + def __init__( + self, + feature_dim: int, + output_dim: int, + num_atomic_types: int, + bias: bool = True, + num_experts: int = 1, + ): + super().__init__() + self.num_experts = num_experts + self.expert_embedding = torch.nn.Embedding( + num_atomic_types, num_experts + ) + self.linear = torch.nn.ModuleList( + [ + torch.nn.Linear(feature_dim, output_dim, bias=bias) + for _ in range(self.num_experts) + ] + ) + + def forward( + self, + batch_species_indices: torch.Tensor, + features: torch.Tensor, + ) -> torch.Tensor: + + expert_weights = torch.softmax( + self.expert_embedding(batch_species_indices), dim=-1 + ).unsqueeze(1) + if features.dim() == 3: + expert_weights = expert_weights.unsqueeze(1) + + expert_outputs = torch.stack( + [ + linear(features) + for linear in self.linear + ], + dim=-1, + ) + + return torch.sum( + expert_outputs * expert_weights, dim=-1 + ) \ No newline at end of file diff --git a/src/metatrain/pet_9/modules/structures.py b/src/metatrain/pet_9/modules/structures.py new file mode 100644 index 0000000000..f7e67f9eaf --- /dev/null +++ b/src/metatrain/pet_9/modules/structures.py @@ -0,0 +1,377 @@ +from typing import List, Optional, Tuple + +import torch +from metatensor.torch import Labels +from metatomic.torch import NeighborListOptions, System + +from .adaptive_cutoff import get_adaptive_cutoffs +from .nef import ( + compute_reversed_neighbor_list, + edge_array_to_nef, + get_corresponding_edges, + get_nef_indices, +) +from .utilities import cutoff_func_bump, cutoff_func_cosine + + +def concatenate_structures( + systems: List[System], + neighbor_list_options: NeighborListOptions, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, +]: + """ + Concatenate a list of systems into a single batch. + + :param systems: List of systems to concatenate. + :param neighbor_list_options: Options for the neighbor list. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the concatenated positions, centers, neighbors, + species, cells, cell shifts, system indices, and sample labels. 
+ """ + + positions: List[torch.Tensor] = [] + centers: List[torch.Tensor] = [] + neighbors: List[torch.Tensor] = [] + species: List[torch.Tensor] = [] + cell_shifts: List[torch.Tensor] = [] + cells: List[torch.Tensor] = [] + system_indices: List[torch.Tensor] = [] + atom_indices: List[torch.Tensor] = [] + node_counter = 0 + + for i, system in enumerate(systems): + assert len(system.known_neighbor_lists()) >= 1, "no neighbor list found" + neighbor_list = system.get_neighbor_list(neighbor_list_options) + nl_values = neighbor_list.samples.values + + centers_values = nl_values[:, 0] + neighbors_values = nl_values[:, 1] + cell_shifts_values = nl_values[:, 2:] + + if selected_atoms is not None: + system_selected_atoms = selected_atoms.values[:, 1][ + selected_atoms.values[:, 0] == i + ] + unique_centers = torch.unique(centers_values) + system_selected_atoms = torch.unique( + torch.cat([system_selected_atoms, unique_centers]) + ) + # calculate the mapping from the ghost atoms to the real atoms + if torch.numel(unique_centers) == 0: + max_center_index = -1 + else: + max_center_index = int(unique_centers.max()) + ghost_to_real_index = torch.full( + [ + max_center_index + 1, + ], + -1, + device=centers_values.device, + dtype=centers_values.dtype, + ) + for j, unique_center_index in enumerate(unique_centers): + ghost_to_real_index[unique_center_index] = j + + centers_values = ghost_to_real_index[centers_values] + neighbors_values = ghost_to_real_index[neighbors_values] + else: + system_selected_atoms = torch.arange( + len(system), device=system.positions.device + ) + + positions.append(system.positions[system_selected_atoms]) + species.append(system.types[system_selected_atoms]) + + centers.append(centers_values + node_counter) + neighbors.append(neighbors_values + node_counter) + cell_shifts.append(cell_shifts_values) + + cells.append(system.cell) + + node_counter += len(system_selected_atoms) + system_indices.append( + torch.full((len(system_selected_atoms),), i, device=system.positions.device) + ) + atom_indices.append( + torch.arange(len(system_selected_atoms), device=system.positions.device) + ) + + positions = torch.cat(positions) + centers = torch.cat(centers) + neighbors = torch.cat(neighbors) + species = torch.cat(species) + cells = torch.stack(cells) + cell_shifts = torch.cat(cell_shifts) + system_indices = torch.cat(system_indices) + atom_indices = torch.cat(atom_indices) + + sample_values = torch.stack( + [system_indices, atom_indices], + dim=1, + ) + sample_labels = Labels( + names=["system", "atom"], + values=sample_values, + assume_unique=True, + ) + + return ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) + + +def systems_to_batch( + systems: List[System], + options: NeighborListOptions, + all_species_list: List[int], + species_to_species_index: torch.Tensor, + cutoff_function: str, + cutoff_width: float, + num_neighbors_adaptive: Optional[float] = None, + selected_atoms: Optional[Labels] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Labels, + torch.Tensor, + torch.Tensor, +]: + """ + Converts a list of systems to a batch required for the PET model. + + :param systems: List of systems to convert to a batch. + :param options: Options for the neighbor list. + :param all_species_list: List of all atomic species in the dataset. + :param species_to_species_index: Mapping from atomic species to species indices. 
+ :param cutoff_function: Type of the smoothing function at the cutoff. + :param cutoff_width: Width of the cutoff function for a cutoff mask. + :param num_neighbors_adaptive: Optional maximum number of neighbors per atom. + If provided, the adaptive cutoff scheme will be used for each atom to + approximately select this number of neighbors. + :param selected_atoms: Optional labels of selected atoms to include in the batch. + :return: A tuple containing the batch tensors. + The batch consists of the following tensors: + - `element_indices_nodes`: The atomic species of the central atoms + - `element_indices_neighbors`: The atomic species of the neighboring atoms + - `edge_vectors`: The cartesian edge vectors between the central atoms and their + neighbors + - `edge_distances`: The distances between the central atoms and their neighbors + - `padding_mask`: A padding mask indicating which neighbors are real, and which + are padded + - `reverse_neighbor_index`: The reversed neighbor list for each central atom + - `cutoff_factors`: The cutoff function values for each edge + - `system_indices`: The system index for each atom in the batch + - `sample_labels`: Labels indicating the system and atom indices for each atom + + """ + ( + positions, + centers, + neighbors, + species, + cells, + cell_shifts, + system_indices, + sample_labels, + ) = concatenate_structures(systems, options, selected_atoms) + + # somehow the backward of this operation is very slow at evaluation, + # where there is only one cell, therefore we simplify the calculation + # for that case + if len(cells) == 1: + cell_contributions = cell_shifts.to(cells.dtype) @ cells[0] + else: + cell_contributions = torch.einsum( + "ab, abc -> ac", + cell_shifts.to(cells.dtype), + cells[system_indices[centers]], + ) + edge_vectors = positions[neighbors] - positions[centers] + cell_contributions + edge_distances = torch.norm(edge_vectors, dim=-1) + 1e-15 + + if selected_atoms is not None: + if torch.numel(centers) == 0: + num_nodes = 0 + else: + num_nodes = int(centers.max()) + 1 + else: + num_nodes = len(positions) + + if num_neighbors_adaptive is not None: + with torch.profiler.record_function("PET::get_adaptive_cutoffs"): + # Adaptive cutoff scheme to approximately select `num_neighbors_adaptive` + # neighbors for each atom + atomic_cutoffs = get_adaptive_cutoffs( + centers, + edge_distances, + num_neighbors_adaptive, + num_nodes, + options.cutoff, + cutoff_width=cutoff_width, + ) + # Symmetrize the cutoffs between pairs of atoms (PET needs this symmetry + # due to its corresponding edge indexing ij -> ji) + pair_cutoffs = (atomic_cutoffs[centers] + atomic_cutoffs[neighbors]) / 2.0 + with torch.profiler.record_function("PET::adaptive_cutoff_masking"): + # Apply cutoff mask + cutoff_mask = edge_distances <= pair_cutoffs + + pair_cutoffs = pair_cutoffs[cutoff_mask] + centers = centers[cutoff_mask] + neighbors = neighbors[cutoff_mask] + edge_vectors = edge_vectors[cutoff_mask] + cell_shifts = cell_shifts[cutoff_mask] + edge_distances = edge_distances[cutoff_mask] + else: + pair_cutoffs = options.cutoff * torch.ones( + len(centers), device=positions.device, dtype=positions.dtype + ) + + num_neighbors = torch.bincount(centers, minlength=num_nodes) + max_edges_per_node = int(torch.max(num_neighbors)) + + # uncomment these to print out stats on the adaptive cutoff behavior + # print("adaptive_cutoffs", *pair_cutoffs.tolist()) + # print("num_neighbors", *num_neighbors.tolist()) + + if cutoff_function.lower() == "bump": + # use bump switching 
function for adaptive cutoff
+        cutoff_factors = cutoff_func_bump(edge_distances, pair_cutoffs, cutoff_width)
+    elif cutoff_function.lower() == "cosine":
+        # backward-compatible cosine switching for fixed cutoff
+        cutoff_factors = cutoff_func_cosine(edge_distances, pair_cutoffs, cutoff_width)
+    else:
+        raise ValueError(
+            f"Unknown cutoff function type: {cutoff_function}. "
+            f"Supported types are 'Cosine' and 'Bump'."
+        )
+
+    # Convert to NEF (Node-Edge-Feature) format:
+    nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices(
+        centers, num_nodes, max_edges_per_node
+    )
+
+    # Element indices
+    element_indices_nodes = species_to_species_index[species]
+    element_indices_neighbors = element_indices_nodes[neighbors]
+
+    # Send everything to NEF:
+    edge_vectors = edge_array_to_nef(edge_vectors, nef_indices)
+    edge_distances = torch.sqrt(torch.sum(edge_vectors**2, dim=2) + 1e-15)
+    element_indices_neighbors = edge_array_to_nef(
+        element_indices_neighbors, nef_indices
+    )
+    cutoff_factors = edge_array_to_nef(cutoff_factors, nef_indices, nef_mask, 0.0)
+
+    corresponding_edges = get_corresponding_edges(
+        torch.concatenate(
+            [centers.unsqueeze(-1), neighbors.unsqueeze(-1), cell_shifts],
+            dim=-1,
+        )
+    )
+
+    # These are the two arrays we need for message passing with edge reversals,
+    # if indexing happens in a two-dimensional way:
+    # edges_ji = edges_ij[neighbors_index, reversed_neighbor_list]
+    reversed_neighbor_list = compute_reversed_neighbor_list(
+        nef_indices, corresponding_edges, nef_mask
+    )
+    neighbors_index = edge_array_to_nef(neighbors, nef_indices).to(torch.int64)
+
+    # Here, we compute the array that allows indexing into a flattened
+    # version of the edge array (where the first two dimensions are merged):
+    reverse_neighbor_index = (
+        neighbors_index * neighbors_index.shape[1] + reversed_neighbor_list
+    )
+    # At this point, we have `reverse_neighbor_index[~nef_mask] = 0`, which however
+    # creates too many of the same index which slows down backward enormously.
+    # (See https://github.com/pytorch/pytorch/issues/41162)
+    # We therefore replace the padded indices with a sequence of unique indices.
+    reverse_neighbor_index[~nef_mask] = torch.arange(
+        int(torch.sum(~nef_mask)), device=reverse_neighbor_index.device
+    )
+
+    return (
+        element_indices_nodes,
+        element_indices_neighbors,
+        edge_vectors,
+        edge_distances,
+        nef_mask,
+        reverse_neighbor_index,
+        cutoff_factors,
+        system_indices,
+        sample_labels,
+        centers,
+        nef_to_edges_neighbor,
+    )
+
+
+def get_pair_sample_labels(
+    systems: List[System],
+    sample_labels: Labels,
+    nl_options: NeighborListOptions,
+    device: torch.device,
+) -> Labels:
+    """
+    Builds the pair sample labels for the input ``systems``, based on the pre-computed
+    neighbor list. These are 'off-site', i.e. not including self-interactions.
+
+    :param systems: List of systems to build the pair sample labels for.
+    :param sample_labels: The sample labels for per-atom quantities.
+    :param nl_options: The neighbor list options to use for building the offsite labels.
+    :param device: The device to put the labels on.
+    :return: ``Labels`` with one entry per off-site neighbor pair across all systems.
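An illustrative example of the resulting layout (values are made up): a row ``[0, 1, 4, 0, 0, 1]`` denotes the pair between atom 1 and atom 4 of system 0, with the second atom shifted by one cell vector along the third lattice direction; the column names are those listed in ``sample_names`` right below.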
+ """ + sample_names = [ + "system", + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ] + + pair_sample_values = [] + for system_idx, system in enumerate(systems): + neighbor_list = system.get_neighbor_list(nl_options) + nl_values = neighbor_list.samples.values + + pair_sample_values.append( + torch.hstack( + [ + torch.full( + (nl_values.shape[0], 1), + system_idx, + dtype=torch.int32, + device=device, + ), + nl_values, + ], + ) + ) + pair_sample_values = torch.vstack(pair_sample_values) + pair_sample_labels = Labels(sample_names, pair_sample_values).to(device=device) + + return pair_sample_labels diff --git a/src/metatrain/pet_9/modules/transformer.py b/src/metatrain/pet_9/modules/transformer.py new file mode 100644 index 0000000000..decb395090 --- /dev/null +++ b/src/metatrain/pet_9/modules/transformer.py @@ -0,0 +1,555 @@ +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from .utilities import DummyModule +from sphericart.torch import SolidHarmonics + + +AVAILABLE_NORMALIZATIONS = ["LayerNorm", "RMSNorm"] +AVAILABLE_TRANSFORMER_TYPES = ["PostLN", "PreLN"] +AVAILABLE_ACTIVATIONS = ["SiLU", "SwiGLU"] + + +class FeedForward(nn.Module): + def __init__(self, d_model: int, dim_feedforward: int, activation: str) -> None: + super().__init__() + + # Check if activation is "swiglu" string + if activation.lower() == "swiglu": + # SwiGLU mode: single projection produces both "value" and "gate" + self.w_in = nn.Linear(d_model, 2 * dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = torch.nn.Identity() + self.is_swiglu = True + else: + # Standard mode: regular activation function + self.w_in = nn.Linear(d_model, dim_feedforward) + self.w_out = nn.Linear(dim_feedforward, d_model) + self.activation = getattr(F, activation.lower()) + self.is_swiglu = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_swiglu: + # SwiGLU activation: split into value and gate + v, g = self.w_in(x).chunk(2, dim=-1) + x = v * torch.sigmoid(g) + x = self.w_out(x) + else: + # Standard activation + x = self.w_in(x) + x = self.activation(x) + x = self.w_out(x) + return x + + +class AttentionBlock(nn.Module): + """ + Multi-head attention block. + + :param total_dim: The total dimension of the input and output tensors. + :param num_heads: The number of attention heads. + :param temperature: An additional scaling factor for attention scores. + This is combined with the standard scaling by the square root of + the head dimension. + :param epsilon: A small value to avoid division by zero. + """ + + def __init__( + self, + total_dim: int, + num_heads: int, + temperature: float, + epsilon: float = 1e-15, + ) -> None: + super(AttentionBlock, self).__init__() + + self.input_linear = nn.Linear(total_dim, 3 * total_dim) + self.output_linear = nn.Linear(total_dim, total_dim) + + self.num_heads = num_heads + self.epsilon = epsilon + self.temperature = temperature + if total_dim % num_heads != 0: + raise ValueError("total dimension is not divisible by the number of heads") + self.head_dim = total_dim // num_heads + + def forward( + self, x: torch.Tensor, cutoff_factors: torch.Tensor, use_manual_attention: bool + ) -> torch.Tensor: + """ + Forward pass for the attention block. 
+ + :param x: The input tensor, of shape (batch_size, seq_length, total_dim) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: The output tensor, of shape (batch_size, seq_length, total_dim) + """ + initial_shape = x.shape + x = self.input_linear(x) + x = x.reshape( + initial_shape[0], initial_shape[1], 3, self.num_heads, self.head_dim + ) + x = x.permute(2, 0, 3, 1, 4) + + queries, keys, values = x[0], x[1], x[2] + attn_weights = torch.clamp(cutoff_factors[:, None, :, :], self.epsilon) + attn_weights = torch.log(attn_weights) + if use_manual_attention: + x = manual_attention(queries, keys, values, attn_weights, self.temperature) + else: + x = torch.nn.functional.scaled_dot_product_attention( + queries, + keys, + values, + attn_mask=attn_weights, + scale=1.0 / (self.head_dim**0.5 * self.temperature), + ) + x = x.transpose(1, 2).reshape(initial_shape) + x = self.output_linear(x) + return x + + +class TransformerLayer(torch.nn.Module): + """ + Single layer of a Transformer. + + :param d_model: The dimension of the model. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param temperature: An additional scaling factor for attention scores. 
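To restate what ``AttentionBlock.forward`` above computes (nothing new, just the formula in one line): the smooth cutoff enters the attention as an additive log-mask, so per head

    weights = softmax(q @ k.transpose(-2, -1) / (sqrt(head_dim) * temperature) + log(clamp(cutoff_factors, eps)))

which drives the weights of edges beyond the cutoff towards zero.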
+ """ + + def __init__( + self, + d_model: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + temperature: float = 1.0, + ) -> None: + super(TransformerLayer, self).__init__() + self.attention = AttentionBlock(d_model, n_heads, temperature) + self.transformer_type = transformer_type + self.d_model = d_model + norm_class = getattr(nn, norm) + self.norm_attention = norm_class(d_model) + self.norm_mlp = norm_class(d_model) + self.mlp = FeedForward(d_model, dim_feedforward, activation) + self.expanded_node_features = False + if dim_node_features != d_model: + self.expanded_node_features = True + self.center_contraction = nn.Linear(dim_node_features, d_model) + self.center_expansion = nn.Linear(d_model, dim_node_features) + self.norm_center_features = norm_class(dim_node_features) + self.center_mlp = FeedForward( + dim_node_features, 2 * dim_node_features, activation + ) + else: + self.center_contraction = torch.nn.Identity() + self.center_expansion = torch.nn.Identity() + self.norm_center_features = torch.nn.Identity() + self.center_mlp = torch.nn.Identity() + + def _forward_pre_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + new_tokens = self.attention( + self.norm_attention(tokens), cutoff_factors, use_manual_attention + ) + output_node_embeddings, output_edge_embeddings = torch.split( + new_tokens, [1, new_tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + + output_edge_embeddings = edge_embeddings + output_edge_embeddings + output_edge_embeddings = output_edge_embeddings + self.mlp( + self.norm_mlp(output_edge_embeddings) + ) + + return output_node_embeddings, output_edge_embeddings + + def _forward_post_ln_impl( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.expanded_node_features: + input_node_embeddings = self.center_contraction(node_embeddings) + else: + input_node_embeddings = node_embeddings + tokens = torch.cat([input_node_embeddings, edge_embeddings], dim=1) + tokens = self.norm_attention( + tokens + self.attention(tokens, cutoff_factors, use_manual_attention) + ) + tokens = self.norm_mlp(tokens + self.mlp(tokens)) + output_node_embeddings, output_edge_embeddings = torch.split( + tokens, [1, tokens.shape[1] - 1], dim=1 + ) + if self.expanded_node_features: + output_node_embeddings = node_embeddings + self.center_expansion( + output_node_embeddings + ) + output_node_embeddings = output_node_embeddings + self.center_mlp( + self.norm_center_features(output_node_embeddings) + ) + return output_node_embeddings, output_edge_embeddings + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward 
pass for a single Transformer layer. + + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + if self.transformer_type == "PostLN": + node_embeddings, edge_embeddings = self._forward_post_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + if self.transformer_type == "PreLN": + node_embeddings, edge_embeddings = self._forward_pre_ln_impl( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class Transformer(torch.nn.Module): + """ + Transformer implementation. + + :param d_model: The dimension of the model. + :param num_layers: The number of transformer layers. + :param n_heads: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param attention_temperature: The temperature scaling factor for attention + scores. This is combined with the standard scaling by the square root of + the head dimension. + """ + + def __init__( + self, + d_model: int, + num_layers: int, + n_heads: int, + dim_node_features: int, + dim_feedforward: int = 512, + norm: str = "LayerNorm", + activation: str = "SiLU", + transformer_type: str = "PostLN", + attention_temperature: float = 1.0, + ) -> None: + super(Transformer, self).__init__() + if norm not in AVAILABLE_NORMALIZATIONS: + raise ValueError( + f"Unknown normalization flag: {norm}. " + f"Please choose from: {AVAILABLE_NORMALIZATIONS}" + ) + + if transformer_type not in AVAILABLE_TRANSFORMER_TYPES: + raise ValueError( + f"Unknown transformer flag: {transformer_type}. " + f"Please choose from: {AVAILABLE_TRANSFORMER_TYPES}" + ) + self.transformer_type = transformer_type + + if activation not in AVAILABLE_ACTIVATIONS: + raise ValueError( + f"Unknown activation flag: {activation}. " + f"Please choose from: {AVAILABLE_ACTIVATIONS}" + ) + + self.layers = nn.ModuleList( + [ + TransformerLayer( + d_model=d_model, + n_heads=n_heads, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + temperature=attention_temperature, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + node_embeddings: torch.Tensor, + edge_embeddings: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the Transformer. 
+ + :param node_embeddings: The input node embeddings, of shape + (batch_size, d_model) + :param edge_embeddings: The input edge embeddings, of shape + (batch_size, seq_length, d_model) + :param cutoff_factors: The cutoff factors for the edges, of shape + (batch_size, seq_length, seq_length) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (batch_size, d_model) + - The output edge embeddings, of shape (batch_size, seq_length, d_model) + """ + for layer in self.layers: + node_embeddings, edge_embeddings = layer( + node_embeddings, edge_embeddings, cutoff_factors, use_manual_attention + ) + return node_embeddings, edge_embeddings + + +class CartesianTransformer(torch.nn.Module): + """ + Cartesian Transformer implementation for handling 3D coordinates. + + :param cutoff: The cutoff distance for neighbor interactions. + :param cutoff_width: The width of the cutoff function. + :param d_model: The dimension of the model. + :param n_head: The number of attention heads. + :param dim_node_features: The dimension of the node features. + :param dim_feedforward: The dimension of the feedforward network. + :param n_layers: The number of transformer layers. + :param norm: The normalization type, either "LayerNorm" or "RMSNorm". + :param activation: The activation function, either "SiLU" or "SwiGLU". + :param attention_temperature: The temperature scaling factor for attention scores. + :param transformer_type: The type of transformer, either "PostLN" or "PreLN". + :param n_atomic_species: The number of atomic species. + :param is_first: Whether this is the first transformer in the model. + """ + + def __init__( + self, + cutoff: float, + cutoff_width: float, + d_model: int, + n_head: int, + dim_node_features: int, + dim_feedforward: int, + n_layers: int, + norm: str, + activation: str, + attention_temperature: float, + transformer_type: str, + n_atomic_species: int, + is_first: bool, + ) -> None: + super(CartesianTransformer, self).__init__() + self.is_first = is_first + self.cutoff = cutoff + self.cutoff_width = cutoff_width + self.trans = Transformer( + d_model=d_model, + num_layers=n_layers, + n_heads=n_head, + dim_node_features=dim_node_features, + dim_feedforward=dim_feedforward, + norm=norm, + activation=activation, + transformer_type=transformer_type, + attention_temperature=attention_temperature, + ) + + self.spherical_harmonics = SolidHarmonics(l_max=10) + self.edge_embedder = nn.Linear(11**2, d_model) + self.rmsnorm = nn.LayerNorm(d_model) + + if not is_first: + n_merge = 3 + else: + n_merge = 2 + + self.compress = nn.Sequential( + nn.Linear(n_merge * d_model, d_model), + torch.nn.SiLU(), + nn.Linear(d_model, d_model), + ) + + self.neighbor_embedder = DummyModule() # for torchscript + if not is_first: + self.neighbor_embedder = nn.Embedding(n_atomic_species, d_model) + + def forward( + self, + input_node_embeddings: torch.Tensor, + input_messages: torch.Tensor, + element_indices_neighbors: torch.Tensor, + edge_vectors: torch.Tensor, + padding_mask: torch.Tensor, + edge_distances: torch.Tensor, + cutoff_factors: torch.Tensor, + use_manual_attention: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the CartesianTransformer. 
+ + :param input_node_embeddings: The input node embeddings, of shape + (n_nodes, d_model) + :param input_messages: The input messages to the transformer, of shape + (n_nodes, max_num_neighbors, d_model) + :param element_indices_neighbors: The atomic species of the neighboring atoms, + of shape (n_nodes, max_num_neighbors) + :param edge_vectors: The cartesian edge vectors between the central atoms and + their neighbors, of shape (n_nodes, max_num_neighbors, 3) + :param padding_mask: A padding mask indicating which neighbors are real, and + which are padded, of shape (n_nodes, max_num_neighbors) + :param edge_distances: The distances between the central atoms and their + neighbors, of shape (n_nodes, max_num_neighbors) + :param cutoff_factors: The cutoff factors for the edges, of shape + (n_nodes, max_num_neighbors) + :param use_manual_attention: Whether to use the manual attention implementation + (which supports double backward, needed for training with conservative + forces), or the built-in PyTorch attention (which does not support double + backward). + :return: A tuple containing: + - The output node embeddings, of shape (n_nodes, d_model) + - The output edge embeddings, of shape (n_nodes, max_num_neighbors, d_model) + """ + node_embeddings = input_node_embeddings + # edge_embeddings = [edge_vectors, edge_distances[:, :, None]] + # edge_embeddings = torch.cat(edge_embeddings, dim=2) + edge_embeddings = self.spherical_harmonics( + edge_vectors.reshape(-1, 3) + ).reshape(edge_vectors.shape[0], edge_vectors.shape[1], -1) + edge_embeddings = self.edge_embedder(edge_embeddings) + edge_embeddings = self.rmsnorm(edge_embeddings) + + if not self.is_first: + neighbor_elements_embeddings = self.neighbor_embedder( + element_indices_neighbors + ) + edge_tokens = torch.cat( + [edge_embeddings, neighbor_elements_embeddings, input_messages], dim=2 + ) + else: + neighbor_elements_embeddings = torch.empty( + 0, device=edge_vectors.device, dtype=edge_vectors.dtype + ) # for torch script + edge_tokens = torch.cat([edge_embeddings, input_messages], dim=2) + + edge_tokens = self.compress(edge_tokens) + # tokens = torch.cat([node_elements_embedding[:, None, :], tokens], dim=1) + + padding_mask_with_central_token = torch.ones( + padding_mask.shape[0], dtype=torch.bool, device=padding_mask.device + ) + total_padding_mask = torch.cat( + [padding_mask_with_central_token[:, None], padding_mask], dim=1 + ) + + cutoff_subfactors = torch.ones( + padding_mask.shape[0], + dtype=cutoff_factors.dtype, + device=padding_mask.device, + ) + cutoff_factors = torch.cat([cutoff_subfactors[:, None], cutoff_factors], dim=1) + cutoff_factors[~total_padding_mask] = 0.0 + + cutoff_factors = cutoff_factors[:, None, :] + cutoff_factors = cutoff_factors.repeat(1, cutoff_factors.shape[2], 1) + + initial_num_tokens = edge_vectors.shape[1] + max_num_tokens = input_messages.shape[1] + + output_node_embeddings, output_edge_embeddings = self.trans( + node_embeddings[:, None, :], + edge_tokens[:, :max_num_tokens, :], + cutoff_factors=cutoff_factors[ + :, : (max_num_tokens + 1), : (max_num_tokens + 1) + ], + use_manual_attention=use_manual_attention, + ) + if max_num_tokens < initial_num_tokens: + padding = torch.zeros( + output_edge_embeddings.shape[0], + initial_num_tokens - max_num_tokens, + output_edge_embeddings.shape[2], + device=output_edge_embeddings.device, + ) + output_edge_embeddings = torch.cat([output_edge_embeddings, padding], dim=1) + output_node_embeddings = output_node_embeddings.squeeze(1) + return output_node_embeddings, 
output_edge_embeddings + + +def manual_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: torch.Tensor, + temperature: float, +) -> torch.Tensor: + """ + Implements the attention operation manually, using basic PyTorch operations. + We need it because the built-in PyTorch attention does not support double backward, + which is needed when training with conservative forces. + + :param q: The queries + :param k: The keys + :param v: The values + :param attn_mask: The attention mask + :param temperature: An additional scaling factor for attention scores. + :return: The result of the attention operation + """ + attention_weights = ( + torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5 * temperature) + ) + attn_mask + attention_weights = attention_weights.softmax(dim=-1) + attention_output = torch.matmul(attention_weights, v) + return attention_output \ No newline at end of file diff --git a/src/metatrain/pet_9/modules/utilities.py b/src/metatrain/pet_9/modules/utilities.py new file mode 100644 index 0000000000..5795c5aed8 --- /dev/null +++ b/src/metatrain/pet_9/modules/utilities.py @@ -0,0 +1,63 @@ +import torch + + +def cutoff_func_bump( + values: torch.Tensor, cutoff: torch.Tensor, width: float, eps: float = 1e-6 +) -> torch.Tensor: + """ + Bump cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :param eps: Avoid computing at values too close to the edges. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + f[mask_active] = 0.5 * ( + 1 + torch.tanh(1 / torch.tan(torch.pi * scaled_values[mask_active])) + ) + f[mask_smaller] = 1.0 + + return f + + +def cutoff_func_cosine( + values: torch.Tensor, cutoff: torch.Tensor, width: float +) -> torch.Tensor: + """ + Cosine cutoff function. + + :param values: Distances at which to evaluate the cutoff function. + :param cutoff: Cutoff radius for each node. + :param width: Width of the cutoff region. + :return: Values of the cutoff function at the specified distances. + """ + + scaled_values = (values - (cutoff - width)) / width + + mask_smaller = scaled_values <= 0.0 + mask_active = (scaled_values > 0.0) & (scaled_values < 1.0) + + f = torch.zeros_like(scaled_values) + + f[mask_active] = 0.5 + 0.5 * torch.cos(torch.pi * scaled_values[mask_active]) + f[mask_smaller] = 1.0 + return f + + +class DummyModule(torch.nn.Module): + """Dummy torch module to make torchscript happy. 
+ This model should never be run""" + + def __init__(self) -> None: + super(DummyModule, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise RuntimeError("This model should never be run") diff --git a/src/metatrain/pet_9/trainer.py b/src/metatrain/pet_9/trainer.py new file mode 100644 index 0000000000..49e82212bc --- /dev/null +++ b/src/metatrain/pet_9/trainer.py @@ -0,0 +1,639 @@ +import copy +import logging +import math +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Union, cast + +import torch +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader, DistributedSampler + +from metatrain.utils.abc import ModelInterface, TrainerInterface +from metatrain.utils.additive import get_remove_additive_transform +from metatrain.utils.augmentation import RotationalAugmenter +from metatrain.utils.data import ( + CollateFn, + CombinedDataLoader, + Dataset, + get_num_workers, + unpack_batch, + validate_num_workers, +) +from metatrain.utils.distributed.batch_utils import should_skip_batch +from metatrain.utils.distributed.distributed_data_parallel import ( + DistributedDataParallel, +) +from metatrain.utils.distributed.slurm import DistributedEnvironment +from metatrain.utils.evaluate_model import evaluate_model +from metatrain.utils.io import check_file_extension +from metatrain.utils.logging import ROOT_LOGGER, MetricLogger +from metatrain.utils.loss import LossAggregator, LossSpecification +from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists_transform, +) +from metatrain.utils.per_atom import average_by_num_atoms +from metatrain.utils.scaler import get_remove_scale_transform +from metatrain.utils.transfer import batch_to + +from . import checkpoints +from .documentation import TrainerHypers +from .model import PET +from .modules.finetuning import apply_finetuning_strategy + + +def get_scheduler( + optimizer: torch.optim.Optimizer, + train_hypers: TrainerHypers, + steps_per_epoch: int, +) -> LambdaLR: + """ + Get a CosineAnnealing learning-rate scheduler with warmup + + :param optimizer: The optimizer for which to create the scheduler. + :param train_hypers: The training hyperparameters. + :param steps_per_epoch: The number of steps per epoch. + :return: The learning rate scheduler. 
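+
+    Example (illustrative sketch, assuming an existing ``model``, a
+    ``train_hypers`` dict providing ``num_epochs`` and ``warmup_fraction``,
+    and 100 optimizer steps per epoch)::
+
+        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
+        scheduler = get_scheduler(optimizer, train_hypers, steps_per_epoch=100)
+        for _ in range(train_hypers["num_epochs"]):
+            for _ in range(100):
+                optimizer.step()
+                scheduler.step()  # linear warmup, then cosine decay towards 0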
+ """ + total_steps = train_hypers["num_epochs"] * steps_per_epoch + warmup_steps = int(train_hypers["warmup_fraction"] * total_steps) + min_lr_ratio = 0.0 # hardcoded for now, could be made configurable in the future + + def lr_lambda(current_step: int) -> float: + if current_step < warmup_steps: + # Linear warmup + return float(current_step) / float(max(1, warmup_steps)) + else: + # Cosine decay + progress = (current_step - warmup_steps) / float( + max(1, total_steps - warmup_steps) + ) + cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) + return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay + + scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) + return scheduler + + +class Trainer(TrainerInterface[TrainerHypers]): + __checkpoint_version__ = 12 + + def __init__(self, hypers: TrainerHypers) -> None: + super().__init__(hypers) + + self.optimizer_state_dict: Optional[Dict[str, Any]] = None + self.scheduler_state_dict: Optional[Dict[str, Any]] = None + self.epoch: Optional[int] = None + self.best_epoch: Optional[int] = None + self.best_metric: Optional[float] = None + self.best_model_state_dict: Optional[Dict[str, Any]] = None + self.best_optimizer_state_dict: Optional[Dict[str, Any]] = None + + def train( + self, + model: PET, + dtype: torch.dtype, + devices: List[torch.device], + train_datasets: List[Union[Dataset, torch.utils.data.Subset]], + val_datasets: List[Union[Dataset, torch.utils.data.Subset]], + checkpoint_dir: str, + ) -> None: + assert dtype in PET.__supported_dtypes__ + + is_distributed = self.hypers["distributed"] + is_finetune = self.hypers["finetune"]["read_from"] is not None + + if is_distributed: + if len(devices) > 1: + raise ValueError( + "Requested distributed training with the `multi-gpu` device. " + " If you want to run distributed training with PET, please " + "set `device` to cuda." 
+ + ) + # the calculation of the device number works both when GPUs on different + # processes are not visible to each other and when they are + distr_env = DistributedEnvironment(self.hypers["distributed_port"]) + device_number = distr_env.local_rank % torch.cuda.device_count() + device = torch.device("cuda", device_number) + torch.distributed.init_process_group(backend="nccl", device_id=device) + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + else: + rank = 0 + device = devices[0] + # only one device, as we don't support non-distributed multi-gpu for now + + if is_distributed: + logging.info(f"Training on {world_size} devices with dtype {dtype}") + else: + logging.info(f"Training on device {device} with dtype {dtype}") + + # Apply fine-tuning strategy if provided + if is_finetune: + assert self.hypers["finetune"]["read_from"] is not None # for mypy + model = apply_finetuning_strategy(model, self.hypers["finetune"]) + method = self.hypers["finetune"]["method"] + num_params = sum(p.numel() for p in model.parameters()) + num_trainable_params = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + + logging.info(f"Applied finetuning strategy: {method}") + logging.info( + f"Number of trainable parameters: {num_trainable_params} " + f"[{num_trainable_params / num_params:.2%}]" + ) + inherit_heads = self.hypers["finetune"]["inherit_heads"] + if inherit_heads: + logging.info( + "Inheriting initial weights for heads and last layers for targets: " + f"from {list(inherit_heads.values())} to " + f"{list(inherit_heads.keys())}" + ) + + # Move the model to the device and dtype: + model.to(device=device, dtype=dtype) + # The additive models of PET are always in float64 (to avoid numerical errors in + # the composition weights, which can be very large).
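+        # Concretely: float32 carries only ~7 significant decimal digits, so for
+        # a composition baseline of order 1e5 the spacing between representable
+        # float32 values is already ~8e-3 and smaller per-atom contributions
+        # added on top of it are rounded away, while the float64 spacing at that
+        # magnitude is ~1e-11. For example (illustrative numbers only),
+        # torch.tensor(1.0e5, dtype=torch.float32) + 1e-4 still compares equal
+        # to 1.0e5, whereas the float64 version does not.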
+ for additive_model in model.additive_models: + additive_model.to(dtype=torch.float64) + model.scaler.to(dtype=torch.float64) + + logging.info("Calculating composition weights") + model.additive_models[0].train_model( # this is the composition model + train_datasets, + model.additive_models[1:], + self.hypers["batch_size"], + is_distributed, + self.hypers["atomic_baseline"], + ) + + if self.hypers["scale_targets"]: + logging.info("Calculating scaling weights") + model.scaler.train_model( + train_datasets, + model.additive_models, + self.hypers["batch_size"], + is_distributed, + self.hypers["fixed_scaling_weights"], + ) + + logging.info("Setting up data loaders") + + if is_distributed: + train_samplers = [ + DistributedSampler( + train_dataset, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + ) + for train_dataset in train_datasets + ] + val_samplers = [ + DistributedSampler( + val_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + for val_dataset in val_datasets + ] + else: + train_samplers = [None] * len(train_datasets) + val_samplers = [None] * len(val_datasets) + + # Extract additive models and scaler and move them to CPU/float64 so they + # can be used in the collate function + model.additive_models[0].weights_to(device="cpu", dtype=torch.float64) + additive_models = copy.deepcopy( + model.additive_models.to(dtype=torch.float64, device="cpu") + ) + model.additive_models.to(device) + model.additive_models[0].weights_to(device=device, dtype=torch.float64) + model.scaler.scales_to(device="cpu", dtype=torch.float64) + scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu")) + model.scaler.to(device) + model.scaler.scales_to(device=device, dtype=torch.float64) + + # Create collate functions: + dataset_info = model.dataset_info + train_targets = dataset_info.targets + extra_data_info = dataset_info.extra_data + rotational_augmenter = RotationalAugmenter( + target_info_dict=train_targets, extra_data_info_dict=extra_data_info + ) + requested_neighbor_lists = get_requested_neighbor_lists(model) + collate_fn_train = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ + rotational_augmenter.apply_random_augmentations, + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + collate_fn_val = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ # no augmentation for validation + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + batch_atom_bounds=self.hypers["batch_atom_bounds"], + ) + + # Create dataloader for the training datasets: + if self.hypers["num_workers"] is None: + num_workers = get_num_workers() + logging.info( + "Number of workers for data-loading not provided and chosen " + f"automatically. Using {num_workers} workers." + ) + else: + num_workers = self.hypers["num_workers"] + validate_num_workers(num_workers) + + train_dataloaders = [] + for train_dataset, train_sampler in zip( + train_datasets, train_samplers, strict=True + ): + if len(train_dataset) < self.hypers["batch_size"]: + raise ValueError( + f"A training dataset has fewer samples " + f"({len(train_dataset)}) than the batch size " + f"({self.hypers['batch_size']}). " + "Please reduce the batch size." 
+ ) + train_dataloaders.append( + DataLoader( + dataset=train_dataset, + batch_size=self.hypers["batch_size"], + sampler=train_sampler, + shuffle=( + # the sampler takes care of this (if present) + train_sampler is None + ), + drop_last=( + # the sampler takes care of this (if present) + train_sampler is None + ), + collate_fn=collate_fn_train, + num_workers=num_workers, + ) + ) + train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True) + + # Create dataloader for the validation datasets: + val_dataloaders = [] + for val_dataset, val_sampler in zip(val_datasets, val_samplers, strict=True): + val_dataloaders.append( + DataLoader( + dataset=val_dataset, + batch_size=self.hypers["batch_size"], + sampler=val_sampler, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_val, + num_workers=num_workers, + ) + ) + val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False) + + if is_distributed: + model = DistributedDataParallel(model, device_ids=[device]) + + outputs_list = [] + for target_name, target_info in train_targets.items(): + outputs_list.append(target_name) + for gradient_name in target_info.gradients: + outputs_list.append(f"{target_name}_{gradient_name}_gradients") + + # Create a loss function: + loss_hypers = cast(Dict[str, LossSpecification], self.hypers["loss"]) # mypy + loss_fn = LossAggregator(targets=train_targets, config=loss_hypers) + logging.info("Using the following loss functions:") + for name, info in loss_fn.metadata.items(): + logging.info(f"{name}:") + main = {k: v for k, v in info.items() if k != "gradients"} + logging.info(main) + if "gradients" not in info or len(info["gradients"]) == 0: + continue + logging.info("With gradients:") + for grad, ginfo in info["gradients"].items(): + logging.info(f"\t{name}::{grad}: {ginfo}") + + if self.hypers["weight_decay"] is not None: + optimizer = torch.optim.AdamW( + model.parameters(), + lr=self.hypers["learning_rate"], + weight_decay=self.hypers["weight_decay"], + ) + else: + optimizer = torch.optim.Adam( + model.parameters(), lr=self.hypers["learning_rate"] + ) + + if self.optimizer_state_dict is not None and not is_finetune: + # try to load the optimizer state dict, but this is only possible + # if there are no new targets in the model (new parameters) + if not (model.module if is_distributed else model).has_new_targets: + optimizer.load_state_dict(self.optimizer_state_dict) + + # Create a learning rate scheduler + lr_scheduler = get_scheduler(optimizer, self.hypers, len(train_dataloader)) + + if self.scheduler_state_dict is not None and not is_finetune: + # same as the optimizer, try to load the scheduler state dict + if not (model.module if is_distributed else model).has_new_targets: + lr_scheduler.load_state_dict(self.scheduler_state_dict) + + per_structure_targets = self.hypers["per_structure_targets"] + + # Log the initial learning rate: + logging.info(f"Base learning rate: {self.hypers['learning_rate']}") + + start_epoch = -1 if self.epoch is None else self.epoch + epoch = start_epoch + + # Save the untrained model checkpoint: + if self.best_metric is None: + self.best_metric = float("inf") + logging.info("Saving untrained model checkpoint") + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = 
lr_scheduler.state_dict() + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_untrained.ckpt", + ) + + # Train the model: + logging.info("Starting training") + start_epoch = start_epoch + 1 + + for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]): + if is_distributed: + for train_sampler in train_samplers: + train_sampler.set_epoch(epoch) + train_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + val_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + if self.hypers["log_mae"]: + train_mae_calculator = MAEAccumulator( + self.hypers["log_separate_blocks"] + ) + val_mae_calculator = MAEAccumulator(self.hypers["log_separate_blocks"]) + + train_loss = 0.0 + for batch in train_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + optimizer.zero_grad() + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=True, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + train_loss_batch = loss_fn(predictions, targets, extra_data) + + if is_distributed: + # make sure all parameters contribute to the gradient calculation + # to make torch DDP happy + for param in model.parameters(): + train_loss_batch += 0.0 * param.sum() + + train_loss_batch.backward() + torch.nn.utils.clip_grad_norm_( + model.parameters(), self.hypers["grad_clip_norm"] + ) + optimizer.step() + lr_scheduler.step() + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(train_loss_batch) + train_loss += train_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + train_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + train_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_train_info = train_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_train_info.update( + train_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + val_loss = 0.0 + for batch in val_dataloader: + # Skip None batches (those outside batch_atom_bounds) + if should_skip_batch(batch, is_distributed, device): + continue + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=False, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + val_loss_batch = loss_fn(predictions, targets, extra_data) + + if 
is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(val_loss_batch) + val_loss += val_loss_batch.item() + + scaled_predictions = (model.module if is_distributed else model).scaler( + systems, predictions + ) + scaled_targets = (model.module if is_distributed else model).scaler( + systems, targets + ) + val_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + val_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_val_info = val_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_val_info.update( + val_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + # Now we log the information: + finalized_train_info = { + "loss": train_loss, + **finalized_train_info, + } + finalized_val_info = { + "loss": val_loss, + **finalized_val_info, + } + + if epoch == start_epoch: + metric_logger = MetricLogger( + log_obj=ROOT_LOGGER, + dataset_info=( + model.module if is_distributed else model + ).dataset_info, + initial_metrics=[finalized_train_info, finalized_val_info], + names=["training", "validation"], + ) + if epoch % self.hypers["log_interval"] == 0: + metric_logger.log( + metrics=[finalized_train_info, finalized_val_info], + epoch=epoch, + rank=rank, + learning_rate=optimizer.param_groups[0]["lr"], + ) + + val_metric = get_selected_metric( + finalized_val_info, self.hypers["best_model_metric"] + ) + if val_metric < self.best_metric: + self.best_metric = val_metric + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + if epoch % self.hypers["checkpoint_interval"] == 0: + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + self.epoch = epoch + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_{epoch}.ckpt", + ) + + # prepare for the checkpoint that will be saved outside the function + self.epoch = epoch + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + + if is_distributed: + torch.distributed.destroy_process_group() + + def save_checkpoint(self, model: ModelInterface, path: Union[str, Path]) -> None: + checkpoint = model.get_checkpoint() + if self.best_model_state_dict is not None: + self.best_model_state_dict["finetune_config"] = model.finetune_config + checkpoint.update( + { + "trainer_ckpt_version": self.__checkpoint_version__, + "train_hypers": self.hypers, + "epoch": self.epoch, + "optimizer_state_dict": self.optimizer_state_dict, + "scheduler_state_dict": self.scheduler_state_dict, + "best_epoch": self.best_epoch, + "best_metric": self.best_metric, + "best_model_state_dict": self.best_model_state_dict, + "best_optimizer_state_dict": self.best_optimizer_state_dict, + } + ) + torch.save( + checkpoint, + check_file_extension(path, ".ckpt"), + ) + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + hypers: TrainerHypers, + context: Literal["restart", "finetune"], + ) -> "Trainer": + trainer = cls(hypers) + trainer.optimizer_state_dict = 
checkpoint["optimizer_state_dict"] + trainer.scheduler_state_dict = checkpoint["scheduler_state_dict"] + trainer.epoch = checkpoint["epoch"] + trainer.best_epoch = checkpoint["best_epoch"] + trainer.best_metric = checkpoint["best_metric"] + trainer.best_model_state_dict = checkpoint["best_model_state_dict"] + trainer.best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"] + + return trainer + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["trainer_ckpt_version"] == v: + update = getattr(checkpoints, f"trainer_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["trainer_ckpt_version"] = v + 1 + + if checkpoint["trainer_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using " + f"trainer version {checkpoint['trainer_ckpt_version']}, while the " + f"current trainer version is {cls.__checkpoint_version__}." + ) + return checkpoint