From 335e80c9c7ff1d486d57f5fc79576b94bd73a4d8 Mon Sep 17 00:00:00 2001 From: Pol Date: Mon, 15 Dec 2025 16:27:43 +0100 Subject: [PATCH 1/5] First sketch of graph2mat --- pyproject.toml | 5 + .../experimental/graph2mat/.gitignore | 1 + .../experimental/graph2mat/__init__.py | 15 + .../experimental/graph2mat/checkpoints.py | 0 .../experimental/graph2mat/documentation.py | 101 ++++ src/metatrain/experimental/graph2mat/model.py | 470 ++++++++++++++ .../graph2mat/modules/__init__.py | 0 .../graph2mat/modules/edge_embedding.py | 42 ++ .../experimental/graph2mat/trainer.py | 571 ++++++++++++++++++ .../experimental/graph2mat/utils/__init__.py | 0 .../experimental/graph2mat/utils/basis.py | 25 + .../experimental/graph2mat/utils/dataset.py | 237 ++++++++ .../experimental/graph2mat/utils/mtt.py | 195 ++++++ .../graph2mat/utils/structures.py | 122 ++++ src/metatrain/share/base_hypers.py | 6 +- src/metatrain/utils/data/dataset.py | 4 +- src/metatrain/utils/data/target_info.py | 62 +- 17 files changed, 1840 insertions(+), 16 deletions(-) create mode 100644 src/metatrain/experimental/graph2mat/.gitignore create mode 100644 src/metatrain/experimental/graph2mat/__init__.py create mode 100644 src/metatrain/experimental/graph2mat/checkpoints.py create mode 100644 src/metatrain/experimental/graph2mat/documentation.py create mode 100644 src/metatrain/experimental/graph2mat/model.py create mode 100644 src/metatrain/experimental/graph2mat/modules/__init__.py create mode 100644 src/metatrain/experimental/graph2mat/modules/edge_embedding.py create mode 100644 src/metatrain/experimental/graph2mat/trainer.py create mode 100644 src/metatrain/experimental/graph2mat/utils/__init__.py create mode 100644 src/metatrain/experimental/graph2mat/utils/basis.py create mode 100644 src/metatrain/experimental/graph2mat/utils/dataset.py create mode 100644 src/metatrain/experimental/graph2mat/utils/mtt.py create mode 100644 src/metatrain/experimental/graph2mat/utils/structures.py diff --git a/pyproject.toml 
"""
Graph2Mat
=========

Interface of ``Graph2Mat`` to all architectures in ``metatrain``.
"""

from typing import Literal, Optional

from typing_extensions import TypedDict

from metatrain.utils.loss import LossSpecification


# NOTE: the bare-string docstrings below are scraped by metatrain's
# documentation tooling to generate the user-facing hyperparameter docs;
# the assigned values are the defaults presented to the user.
class ModelHypers(TypedDict):
    featurizer_architecture: dict
    """Architecture that provides the features for graph2mat.

    This hyperparameter can contain the full specification for the
    architecture, i.e. everything that goes inside the ``architecture``
    field of a normal training run for that architecture.
    """
    basis_yaml: str = "."
    """Yaml file with the full basis specification for graph2mat.

    This file contains a list, with each item being a dictionary
    to initialize a ``graph2mat.PointBasis`` object.
    """
    basis_grouping: Literal["point_type", "basis_shape", "max"] = "point_type"
    """The way in which graph2mat should group basis (to reduce the number of heads)"""
    node_hidden_irreps: str = "20x0e+20x1o+20x2e"
    """Irreps to ask for to the featurizer (per atom).

    Graph2Mat will take these features as input.
    """
    edge_hidden_irreps: str = "10x0e+10x1o+10x2e"
    """Hidden irreps for the edges inside graph2mat"""


class TrainerHypers(TypedDict):
    # --- Optimizer hypers ---
    optimizer: str = "Adam"
    """Optimizer for parameter optimization.

    We just take the class from torch.optim by name, so make
    sure it is a valid torch optimizer (including possible
    uppercase/lowercase differences).
    """
    optimizer_kwargs: dict = {"lr": 0.01}
    """Keyword arguments to pass to the optimizer.

    These will depend on the optimizer chosen.
    """

    # --- LR scheduler hypers ---
    lr_scheduler: Optional[str] = "ReduceLROnPlateau"  # Named "scheduler" in MACE
    """Learning rate scheduler to use.

    We just take the class from torch.optim.lr_scheduler by name, so make
    sure it is a valid torch scheduler (including possible
    uppercase/lowercase differences).

    None means no scheduler will be used.
    """
    lr_scheduler_kwargs: dict = {}
    """Keyword arguments to pass to the learning rate scheduler.

    These will depend on the scheduler chosen.
    """

    # --- General training parameters that are shared across architectures ---
    distributed: bool = False
    """Whether to use distributed training"""
    distributed_port: int = 39591
    """Port for DDP communication"""
    batch_size: int = 16
    """The number of samples to use in each batch of training. This
    hyperparameter controls the tradeoff between training speed and memory usage. In
    general, larger batch sizes will lead to faster training, but might require more
    memory."""
    num_epochs: int = 1000
    """Number of epochs."""
    log_interval: int = 1
    """Interval to log metrics."""
    checkpoint_interval: int = 100
    """Interval to save checkpoints."""
    per_structure_targets: list[str] = []
    """Targets to calculate per-structure losses."""
    num_workers: Optional[int] = None
    """Number of workers for data loading. If not provided, it is set automatically."""
    log_mae: bool = True
    """Log MAE alongside RMSE"""
    log_separate_blocks: bool = False
    """Log per-block error."""
    best_model_metric: Literal["rmse_prod", "mae_prod", "loss"] = "mae_prod"
    """Metric used to select best checkpoint (e.g., ``rmse_prod``)"""
    grad_clip_norm: float = 1.0
    """Maximum gradient norm value"""
    loss: str | dict[str, LossSpecification] = "mse"
    """This section describes the loss function to be used. See the
    :ref:`loss-functions` for more details."""
See the + :ref:`loss-functions` for more details.""" diff --git a/src/metatrain/experimental/graph2mat/model.py b/src/metatrain/experimental/graph2mat/model.py new file mode 100644 index 0000000000..c7f7b7475c --- /dev/null +++ b/src/metatrain/experimental/graph2mat/model.py @@ -0,0 +1,470 @@ +import logging +from typing import Any, Dict, List, Literal, Optional + +import numpy as np +import torch +from e3nn import o3 +from graph2mat import MatrixDataProcessor +from graph2mat.bindings.e3nn import E3nnGraph2Mat +from metatensor.torch import Labels, TensorMap +from metatomic.torch import ( + AtomisticModel, + ModelCapabilities, + ModelMetadata, + ModelOutput, + NeighborListOptions, + System, +) + +from metatrain.utils.abc import ModelInterface +from metatrain.utils.architectures import get_default_hypers, import_architecture +from metatrain.utils.data import DatasetInfo +from metatrain.utils.dtype import dtype_to_str +from metatrain.utils.metadata import merge_metadata + +from .documentation import ModelHypers +from .modules.edge_embedding import BesselBasis +from .utils.basis import get_basis_table_from_yaml +from .utils.mtt import g2m_labels_to_tensormap, split_dataset_info +from .utils.structures import create_batch, get_edge_vectors_and_lengths + + +class MetaGraph2Mat(ModelInterface[ModelHypers]): + """Interface of MACE for metatrain.""" + + __checkpoint_version__ = 1 + __supported_devices__ = ["cuda", "cpu"] + __supported_dtypes__ = [torch.float64, torch.float32] + __default_metadata__ = ModelMetadata( + references={ + "architecture": [ + "https://iopscience.iop.org/article/10.1088/2632-2153/adc871" + ] + } + ) + + def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None: + super().__init__(hypers, dataset_info, self.__default_metadata__) + + # ----------------------------------------------------------- + # Split dataset info into targets that Graph2Mat handles + # and those that will be handled by the featurizer itself + # 
----------------------------------------------------------- + self.featurizer_dataset_info, self.graph2mat_dataset_info = split_dataset_info( + dataset_info=dataset_info, + node_hidden_irreps=self.hypers["node_hidden_irreps"], + ) + + # --------------------------- + # Initialize the featurizer + # --------------------------- + # We use the "featurizer_architecture" hyper to initialize a model. + + featurizer_name = self.hypers["featurizer_architecture"]["name"] + featurizer_arch = import_architecture(featurizer_name) + default_hypers = get_default_hypers(featurizer_name) + model_hypers = { + **default_hypers["model"], + **self.hypers["featurizer_architecture"].get("model", {}), + } + self.featurizer_model = featurizer_arch.__model__( + hypers=model_hypers, + dataset_info=self.featurizer_dataset_info, + ) + + # ---------------------------------------------------- + # Prepare things for initializing Graph2Mat + # ---------------------------------------------------- + + # Get the basis, this will likely be a different basis table + # per target in the end, let's see + basis_table = get_basis_table_from_yaml(self.hypers["basis_yaml"]) + + # Atomic types, and helper to convert from atomic type (Z) to index + # in the basis table. + self.atomic_types = [atom.Z for atom in basis_table.atoms] + self.register_buffer( + "atomic_types_to_species_index", + torch.zeros(max(self.atomic_types) + 1, dtype=torch.int64), + ) + for i, atomic_type in enumerate(self.atomic_types): + self.atomic_types_to_species_index[atomic_type] = i + + # Functions to embed edges for graph2mat. + # Radial embedding (i.e. embedding of the edge length). + n_basis = 8 + self.radial_embedding = BesselBasis( + r_max=np.max(basis_table.R), num_basis=n_basis + ) + + # Embedding of the direction of the edge. 
+ sh_irreps = o3.Irreps.spherical_harmonics(2) + self.spherical_harmonics = o3.SphericalHarmonics( + sh_irreps, normalize=True, normalization="component" + ) + + # Irreps for all the inputs that graph2mat will take. + graph2mat_irreps = dict( + # One hot encoding of species + node_attrs_irreps=o3.Irreps("0e") * len(self.atomic_types), + # Features coming from the featurizer + node_feats_irreps=o3.Irreps(self.hypers["node_hidden_irreps"]), + # Embedding of the edges direction. + edge_attrs_irreps=sh_irreps, + # Embedding of the edges length. + edge_feats_irreps=o3.Irreps(f"{n_basis}x0e"), + # Internal irreps for graph2mat + edge_hidden_irreps=o3.Irreps(self.hypers["edge_hidden_irreps"]), + ) + + # ---------------------------------------------------- + # Initialize one Graph2Mat per target + # ---------------------------------------------------- + self.graph2mats = torch.nn.ModuleDict() + self.graph2mat_nls: dict[str, NeighborListOptions] = {} + self.graph2mat_processors: dict[str, MatrixDataProcessor] = {} + for i, target_name in enumerate(self.graph2mat_dataset_info.targets): + # Get the matrix processor for this target. + data_processor = MatrixDataProcessor( + basis_table=basis_table, + symmetric_matrix=True, + sub_point_matrix=False, + out_matrix=target_name, + node_attr_getters=[], + ) + self.graph2mat_processors[target_name] = data_processor + + # Initialize graph2mat. + self.graph2mats[target_name] = E3nnGraph2Mat( + unique_basis=data_processor.basis_table.basis, + irreps=graph2mat_irreps, + symmetric=data_processor.symmetric_matrix, + basis_grouping=self.hypers["basis_grouping"], + ) + + # The neighbor list options are ignored, since the neighbor lists + # are created by graph2mat according to the basis. + # Here we just make sure we have a unique neighbor list for + # each graph2mat. 
+ self.graph2mat_nls[target_name] = NeighborListOptions( + cutoff=80.999 + i * 0.03, + full_list=True, + strict=True, + requestor=f"graph2mat_{target_name}", + ) + + # --------------------------- + # Outputs definition + # --------------------------- + + all_targets = { + **self.featurizer_dataset_info.targets, + **self.graph2mat_dataset_info.targets, + } + + self.outputs = { + k: ModelOutput( + quantity=target_info.quantity, + unit=target_info.unit, + per_atom=True, + ) + for k, target_info in all_targets.items() + } + + # --------------------------- + # Data preprocessing modules + # --------------------------- + + # For now we don't have additive contributions or scaling. + # self.additive_models = torch.nn.ModuleList([]) + + # self.scaler = Scaler(hypers={}, dataset_info=self.featurizer_dataset_info) + + # self.finetune_config: Dict[str, Any] = {} + + def restart(self, dataset_info: DatasetInfo) -> "MetaGraph2Mat": + # Check that the new dataset info does not contain new atomic types + if new_atomic_types := set(dataset_info.atomic_types) - set( + self.dataset_info.atomic_types + ): + raise ValueError( + f"New atomic types found in the dataset: {new_atomic_types}. " + "The MACE model does not support adding new atomic types." 
+ ) + + # Merge the old dataset info with the new one + merged_info = self.dataset_info.union(dataset_info) + + # Check if there are new targets + new_targets = { + key: value + for key, value in merged_info.targets.items() + if key not in self.dataset_info.targets + } + self.has_new_targets = len(new_targets) > 0 + + # Add extra heads for the new targets + for target_name, target in new_targets.items(): + self._add_output(target_name, target) + + self.dataset_info = merged_info + + # restart the composition and scaler models + # self.additive_models[0] = self.additive_models[0].restart( + # dataset_info=DatasetInfo( + # length_unit=dataset_info.length_unit, + # atomic_types=self.dataset_info.atomic_types, + # targets={ + # target_name: target_info + # for target_name, target_info in dataset_info.targets.items() + # if CompositionModel.is_valid_target(target_name, target_info) + # }, + # ), + # ) + # self.scaler = self.scaler.restart(dataset_info) + + return self + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + if selected_atoms is not None: + raise NotImplementedError("selected_atoms not implemented yet") + + # ------------------------------------------------------- + # Split outputs according to whether the featurizer or + # graph2mat will handle them + # ------------------------------------------------------- + + featurizer_outputs = { + k: v + for k, v in outputs.items() + if k in self.featurizer_dataset_info.targets + } + graph2mat_outputs = { + k: v for k, v in outputs.items() if k in self.graph2mat_dataset_info.targets + } + + # ----------------------------- + # Featurizer forward pass + # ----------------------------- + # We add extra outputs to the featurizer to retrieve the node + # features that graph2mat will use. 
+ + featurizer_return = self.featurizer_model.forward( + systems=systems, + outputs={ + **featurizer_outputs, + **{ + f"mtt::aux::graph2mat_{target_name}": ModelOutput( + quantity="", + unit="", + per_atom=True, + ) + for target_name in graph2mat_outputs + }, + }, + selected_atoms=selected_atoms, + ) + + # ---------------------------------------------------------------- + # Concatenate tensormap outputs to get flat tensors (e3nn-like) + # ---------------------------------------------------------------- + + graph2mat_inputs = {} + # Concatenate outputs to get the e3nn representations from the tensormap + for target_name in graph2mat_outputs: + graph2mat_inputs[target_name] = [] + + tensormap = featurizer_return.pop(f"mtt::aux::graph2mat_{target_name}") + + for block in tensormap.blocks(): + # Move components dimension to last and then flatten to get (n_atoms, irreps_dim) + block_values = block.values.transpose(1, 2) + graph2mat_inputs[target_name].append( + block_values.reshape(block_values.shape[0], -1) + ) + + graph2mat_inputs[target_name] = torch.cat( + graph2mat_inputs[target_name], dim=-1 + ) + + # ----------------------------- + # Run each Graph2Mat + # ----------------------------- + + graph2mat_returns = {} + + for target_name, graph2mat in self.graph2mats.items(): + if target_name not in graph2mat_outputs: + continue + + # Create the batch with the graph that this graph2mat will use + data = create_batch( + systems=systems, + neighbor_list_options=self.graph2mat_nls[target_name], + atomic_types_to_species_index=self.atomic_types_to_species_index, + n_types=len(self.atomic_types), + data_processor=self.graph2mat_processors[target_name], + ) + + # Convert coordinates from XYZ to YZX so that the outputs are spherical + # harmonics. 
+ data["positions"] = data["positions"][:, [1, 2, 0]] + data["cell"] = data["cell"][:, [1, 2, 0]] + data["shifts"] = data["shifts"][:, [1, 2, 0]] + + # Embed edges and add them to the batch + vectors, lengths = get_edge_vectors_and_lengths( + positions=data["positions"], + edge_index=data["edge_index"], + shifts=data["shifts"], + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding(lengths) + + data["edge_attrs"] = edge_attrs + data["edge_feats"] = edge_feats + + # Run graph2mat and store the outputs (a tuple of tensors: node labels and edge labels) + graph2mat_returns[target_name] = graph2mat( + data=data, node_feats=graph2mat_inputs[target_name] + ) + + # ----------------------------------- + # Convert outputs to TensorMaps + # ----------------------------------- + + # At this point, we have a dictionary of all outputs as normal torch tensors. + # Now, we simply convert to TensorMaps. + + # Get the labels for the samples (system and atom of each value) + + return_dict: Dict[str, TensorMap] = { + **featurizer_return, + **{ + output_name: g2m_labels_to_tensormap( + node_labels=graph2mat_returns[output_name][0], + edge_labels=graph2mat_returns[output_name][1], + ) + for output_name in graph2mat_outputs + }, + } + + return return_dict + + def supported_outputs(self) -> Dict[str, ModelOutput]: + return self.outputs + + def requested_neighbor_lists( + self, + ) -> List[NeighborListOptions]: + return [self.requested_nl] + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + context: Literal["restart", "finetune", "export"], + ) -> "MetaGraph2Mat": + if context == "restart": + logging.info(f"Using latest model from epoch {checkpoint['epoch']}") + model_state_dict = checkpoint["model_state_dict"] + elif context in {"finetune", "export"}: + logging.info(f"Using best model from epoch {checkpoint['best_epoch']}") + model_state_dict = checkpoint["best_model_state_dict"] + if model_state_dict is None: + 
model_state_dict = checkpoint["model_state_dict"] + else: + raise ValueError("Unknown context tag for checkpoint loading!") + + # Create the model + model_data = checkpoint["model_data"] + model = cls(**model_data) + # Infer dtype + dtype = None + # Otherwise, just look at the weights in the state dict + for k, v in model_state_dict.items(): + if k.endswith(".weight"): + dtype = v.dtype + break + else: + raise ValueError("Couldn't infer dtype from the checkpoint file") + + # Set up composition and scaler models + # model.additive_models[0].sync_tensor_maps() + # model.scaler.sync_tensor_maps() + + # Loading the metadata from the checkpoint + model.metadata = merge_metadata(model.metadata, checkpoint.get("metadata")) + + return model + + def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: + raise NotImplementedError("Export not implemented yet for MetaGraph2Mat") + dtype = next(self.parameters()).dtype + if dtype not in self.__supported_dtypes__: + raise ValueError(f"unsupported dtype {dtype} for MACE") + + # Make sure the model is all in the same dtype + # For example, after training, the additive models could still be in + # float64 + self.to(dtype) + + # Additionally, the composition model contains some `TensorMap`s that cannot + # be registered correctly with Pytorch. 
This function moves them: + self.additive_models[0].weights_to(torch.device("cpu"), torch.float64) + + interaction_range = self.hypers["num_interactions"] * self.cutoff + + capabilities = ModelCapabilities( + outputs=self.outputs, + atomic_types=self.atomic_types, + interaction_range=interaction_range, + length_unit=self.dataset_info.length_unit, + supported_devices=self.__supported_devices__, + dtype=dtype_to_str(dtype), + ) + + metadata = merge_metadata(self.metadata, metadata) + + return AtomisticModel(jit.compile(self.eval()), metadata, capabilities) + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using model " + f"version {checkpoint['model_ckpt_version']}, while the current model " + f"version is {cls.__checkpoint_version__}." + ) + + return checkpoint + + def get_checkpoint(self) -> Dict: + model_state_dict = self.state_dict() + + # If the MACE model was passed as part of the hypers, we store it + # again as part of the hypers. 
class BesselBasis(torch.nn.Module):
    """Embedding of distances using a Bessel radial basis set.

    Maps an edge length ``x`` of shape ``[..., 1]`` to ``num_basis``
    features ``sqrt(2 / r_max) * sin(n * pi * x / r_max) / x`` for
    ``n = 1..num_basis``.

    :param r_max: Cutoff radius of the basis.
    :param num_basis: Number of basis functions.
    :param trainable: If True, the basis frequencies are learnable
        parameters instead of fixed buffers.
    """

    def __init__(self, r_max: float, num_basis=8, trainable=False):
        super().__init__()

        # Frequencies n * pi / r_max for n = 1..num_basis.
        bessel_weights = (
            np.pi
            / r_max
            * torch.linspace(
                start=1.0,
                end=num_basis,
                steps=num_basis,
                dtype=torch.get_default_dtype(),
            )
        )
        if trainable:
            self.bessel_weights = torch.nn.Parameter(bessel_weights)
        else:
            self.register_buffer("bessel_weights", bessel_weights)

        self.register_buffer(
            "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
        )
        self.register_buffer(
            "prefactor",
            torch.tensor(np.sqrt(2.0 / r_max), dtype=torch.get_default_dtype()),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # [..., 1]
        # NOTE(review): undefined at x == 0 (0/0 -> NaN); assumes strictly
        # positive edge lengths, as produced by a neighbor list.
        numerator = torch.sin(self.bessel_weights * x)  # [..., num_basis]
        return self.prefactor * (numerator / x)

    def __repr__(self):
        # Fix: format the 0-d r_max buffer as a plain float; the original
        # interpolated the tensor itself, printing "r_max=tensor(...)".
        return (
            f"{self.__class__.__name__}(r_max={float(self.r_max)}, "
            f"num_basis={len(self.bessel_weights)}, "
            f"trainable={self.bessel_weights.requires_grad})"
        )
def filter_out_nans(
    targets: Dict[str, Any], predictions: Optional[Dict[str, Any]] = None
) -> tuple:
    """Build copies of the targets (and predictions) with NaN entries removed.

    Each target TensorMap is expected to have exactly two blocks — node
    labels and edge labels, in that order — as produced by
    ``g2m_labels_to_tensormap``. Predictions are filtered with the same NaN
    mask as the corresponding target so that both stay aligned.

    Fix: the original annotated the return type as ``None`` even though the
    function returns a tuple.

    :param targets: Dictionary of target TensorMaps.
    :param predictions: Optional dictionary of prediction TensorMaps.
    :return: Tuple ``(clean_targets, clean_predictions)``; the second
        element is ``None`` when ``predictions`` is ``None``.
    """
    clean_targets: Dict[str, Any] = {}
    clean_predictions: Optional[Dict[str, Any]] = (
        {} if predictions is not None else None
    )
    for key, target_tensormap in targets.items():
        filtered_target_blocks = []
        filtered_prediction_blocks = []
        for i, block in enumerate(target_tensormap.blocks()):
            mask = torch.isnan(block.values)
            filtered_target_blocks.append(block.values[~mask])
            if predictions is not None:
                filtered_prediction_blocks.append(
                    predictions[key].block(i).values[~mask]
                )

        clean_targets[key] = g2m_labels_to_tensormap(
            node_labels=filtered_target_blocks[0],
            edge_labels=filtered_target_blocks[1],
        )

        if clean_predictions is not None:
            clean_predictions[key] = g2m_labels_to_tensormap(
                node_labels=filtered_prediction_blocks[0],
                edge_labels=filtered_prediction_blocks[1],
            )

    return clean_targets, clean_predictions
+ ) + # the calculation of the device number works both when GPUs on different + # processes are not visible to each other and when they are + distr_env = DistributedEnvironment(self.hypers["distributed_port"]) + device_number = distr_env.local_rank % torch.cuda.device_count() + device = torch.device("cuda", device_number) + torch.distributed.init_process_group(backend="nccl", device_id=device) + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + else: + rank = 0 + device = devices[0] + # only one device, as we don't support non-distributed multi-gpu for now + + if is_distributed: + logging.info(f"Training on {world_size} devices with dtype {dtype}") + else: + logging.info(f"Training on device {device} with dtype {dtype}") + + # Move the model to the device and dtype: + model.to(device=device, dtype=dtype) + # The additive models are always in float64 (to avoid numerical + # errors in the composition weights, which can be very large). + # for additive_model in model.additive_models: + # additive_model.to(dtype=torch.float64) + # model.scaler.to(dtype=torch.float64) + + # if self.hypers["scale_targets"] + # logging.info("Calculating scaling weights") + # model.scaler.train_model( + # train_datasets, + # model.additive_models, + # self.hypers["batch_size"], + # is_distributed, + # { + # **model.get_fixed_scaling_weights(), + # **self.hypers["fixed_scaling_weights"], + # }, + # ) + + logging.info("Setting up data loaders") + + if is_distributed: + train_samplers = [ + DistributedSampler( + train_dataset, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + ) + for train_dataset in train_datasets + ] + val_samplers = [ + DistributedSampler( + val_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + for val_dataset in val_datasets + ] + else: + train_samplers = [None] * len(train_datasets) + val_samplers = [None] * len(val_datasets) + + # Extract additive models and scaler 
and move them to CPU/float64 so they + # can be used in the collate function + # model.scaler.scales_to(device="cpu", dtype=torch.float64) + # scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu")) + # model.scaler.to(device) + # model.scaler.scales_to(device=device, dtype=torch.float64) + + # Create collate functions: + dataset_info = model.dataset_info + train_targets = dataset_info.targets + requested_neighbor_lists = get_requested_neighbor_lists(model.featurizer_model) + collate_fn = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + # get_remove_scale_transform(scaler), + get_graph2mat_transform( + model.graph2mat_processors, model.graph2mat_nls + ), + ], + ) + + # Create dataloader for the training datasets: + if self.hypers["num_workers"] is None: + num_workers = get_num_workers() + logging.info( + "Number of workers for data-loading not provided and chosen " + f"automatically. Using {num_workers} workers." + ) + else: + num_workers = self.hypers["num_workers"] + validate_num_workers(num_workers) + + train_dataloaders = [] + for train_dataset, train_sampler in zip( + train_datasets, train_samplers, strict=True + ): + if len(train_dataset) < self.hypers["batch_size"]: + raise ValueError( + f"A training dataset has fewer samples " + f"({len(train_dataset)}) than the batch size " + f"({self.hypers['batch_size']}). " + "Please reduce the batch size." 
+ ) + train_dataloaders.append( + DataLoader( + dataset=train_dataset, + batch_size=self.hypers["batch_size"], + sampler=train_sampler, + shuffle=( + # the sampler takes care of this (if present) + train_sampler is None + ), + drop_last=( + # the sampler takes care of this (if present) + train_sampler is None + ), + collate_fn=collate_fn, + num_workers=num_workers, + ) + ) + train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True) + + # Create dataloader for the validation datasets: + val_dataloaders = [] + for val_dataset, val_sampler in zip(val_datasets, val_samplers, strict=True): + if len(val_dataset) < self.hypers["batch_size"]: + raise ValueError( + f"A validation dataset has fewer samples " + f"({len(val_dataset)}) than the batch size " + f"({self.hypers['batch_size']}). " + "Please reduce the batch size." + ) + val_dataloaders.append( + DataLoader( + dataset=val_dataset, + batch_size=self.hypers["batch_size"], + sampler=val_sampler, + shuffle=False, + drop_last=False, + collate_fn=collate_fn, + num_workers=num_workers, + ) + ) + val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False) + + if is_distributed: + model = DistributedDataParallel(model, device_ids=[device]) + + outputs_list = [] + for target_name, target_info in train_targets.items(): + outputs_list.append(target_name) + for gradient_name in target_info.gradients: + outputs_list.append(f"{target_name}_{gradient_name}_gradients") + + # Create a loss function: + loss_hypers = self.hypers["loss"] + loss_fn = LossAggregator( + targets=train_targets, + config=loss_hypers, + ) + logging.info("Using the following loss functions:") + for name, info in loss_fn.metadata.items(): + logging.info(f"{name}:") + main = {k: v for k, v in info.items() if k != "gradients"} + logging.info(main) + if "gradients" not in info or len(info["gradients"]) == 0: + continue + logging.info("With gradients:") + for grad, ginfo in info["gradients"].items(): + logging.info(f"\t{name}::{grad}: {ginfo}") 
+ + optimizer = getattr(torch.optim, self.hypers["optimizer"])( + model.parameters(), **self.hypers["optimizer_kwargs"] + ) + + lr_scheduler = None + if self.hypers["lr_scheduler"] is not None: + lr_scheduler = getattr( + torch.optim.lr_scheduler, self.hypers["lr_scheduler"] + )( + optimizer, + **self.hypers["lr_scheduler_kwargs"], + ) + + per_structure_targets = [ + *model.graph2mat_dataset_info.targets, + *self.hypers["per_structure_targets"], + ] + + # Log the initial learning rate: + logging.info(f"Base learning rate: {optimizer.param_groups[0]['lr']}") + + start_epoch = 0 if self.epoch is None else self.epoch + 1 + + # Train the model: + if self.best_metric is None: + self.best_metric = float("inf") + logging.info("Starting training") + epoch = start_epoch + + for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]): + if is_distributed: + for train_sampler in train_samplers: + train_sampler.set_epoch(epoch) + train_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + val_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + if self.hypers["log_mae"]: + train_mae_calculator = MAEAccumulator( + self.hypers["log_separate_blocks"] + ) + val_mae_calculator = MAEAccumulator(self.hypers["log_separate_blocks"]) + + train_loss = 0.0 + for batch in train_dataloader: + optimizer.zero_grad() + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=True, + ) + + targets, predictions = filter_out_nans(targets, predictions) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + + train_loss_batch = loss_fn(predictions, targets, extra_data) + + if 
is_distributed: + # make sure all parameters contribute to the gradient calculation + # to make torch DDP happy + for param in model.parameters(): + train_loss_batch += 0.0 * param.sum() + + train_loss_batch.backward() + torch.nn.utils.clip_grad_norm_( + model.parameters(), self.hypers["grad_clip_norm"] + ) + optimizer.step() + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(train_loss_batch) + train_loss += train_loss_batch.item() + + # scaled_predictions = (model.module if is_distributed else model).scaler( + # systems, predictions + # ) + # scaled_targets = (model.module if is_distributed else model).scaler( + # systems, targets + # ) + scaled_predictions = predictions + scaled_targets = targets + train_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + train_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_train_info = train_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_train_info.update( + train_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + val_loss = 0.0 + for batch in val_dataloader: + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=False, + ) + + targets, predictions = filter_out_nans(targets, predictions) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + val_loss_batch = loss_fn(predictions, targets, 
extra_data) + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(val_loss_batch) + val_loss += val_loss_batch.item() + + # scaled_predictions = (model.module if is_distributed else model).scaler( + # systems, predictions + # ) + # scaled_targets = (model.module if is_distributed else model).scaler( + # systems, targets + # ) + scaled_predictions = predictions + scaled_targets = targets + val_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + val_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + if lr_scheduler is not None: + lr_scheduler.step(metrics=val_loss) + + finalized_val_info = val_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_val_info.update( + val_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + # Now we log the information: + finalized_train_info = { + "loss": train_loss, + **finalized_train_info, + } + finalized_val_info = { + "loss": val_loss, + **finalized_val_info, + } + + if epoch == start_epoch: + metric_logger = MetricLogger( + log_obj=ROOT_LOGGER, + dataset_info=( + model.module if is_distributed else model + ).dataset_info, + initial_metrics=[finalized_train_info, finalized_val_info], + names=["training", "validation"], + ) + if epoch % self.hypers["log_interval"] == 0: + metric_logger.log( + metrics=[finalized_train_info, finalized_val_info], + epoch=epoch, + rank=rank, + learning_rate=optimizer.param_groups[0]["lr"], + ) + + val_metric = get_selected_metric( + finalized_val_info, self.hypers["best_model_metric"] + ) + if val_metric < self.best_metric: + self.best_metric = val_metric + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() 
+ ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + if epoch % self.hypers["checkpoint_interval"] == 0: + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + self.epoch = epoch + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_{epoch}.ckpt", + ) + + # prepare for the checkpoint that will be saved outside the function + self.epoch = epoch + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + + if is_distributed: + torch.distributed.destroy_process_group() + + def save_checkpoint(self, model: ModelInterface, path: Union[str, Path]) -> None: + checkpoint = model.get_checkpoint() + checkpoint.update( + { + "trainer_ckpt_version": self.__checkpoint_version__, + "train_hypers": self.hypers, + "epoch": self.epoch, + "optimizer_state_dict": self.optimizer_state_dict, + "scheduler_state_dict": self.scheduler_state_dict, + "best_epoch": self.best_epoch, + "best_metric": self.best_metric, + "best_model_state_dict": self.best_model_state_dict, + "best_optimizer_state_dict": self.best_optimizer_state_dict, + } + ) + torch.save( + checkpoint, + check_file_extension(path, ".ckpt"), + ) + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + hypers: TrainerHypers, + context: Literal["restart", "finetune"], + ) -> "Trainer": + trainer = cls(hypers) + trainer.optimizer_state_dict = checkpoint["optimizer_state_dict"] + trainer.scheduler_state_dict = checkpoint["scheduler_state_dict"] + trainer.epoch = checkpoint["epoch"] + trainer.best_epoch = checkpoint["best_epoch"] + trainer.best_metric = checkpoint["best_metric"] + trainer.best_model_state_dict = checkpoint["best_model_state_dict"] + trainer.best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"] + + return 
trainer + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["trainer_ckpt_version"] == v: + update = getattr(checkpoints, f"trainer_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["trainer_ckpt_version"] = v + 1 + + if checkpoint["trainer_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using " + f"trainer version {checkpoint['trainer_ckpt_version']}, while the " + f"current trainer version is {cls.__checkpoint_version__}." + ) + return checkpoint diff --git a/src/metatrain/experimental/graph2mat/utils/__init__.py b/src/metatrain/experimental/graph2mat/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/metatrain/experimental/graph2mat/utils/basis.py b/src/metatrain/experimental/graph2mat/utils/basis.py new file mode 100644 index 0000000000..4ff520abde --- /dev/null +++ b/src/metatrain/experimental/graph2mat/utils/basis.py @@ -0,0 +1,25 @@ +import numpy as np +import yaml +from graph2mat import AtomicTableWithEdges, PointBasis + + +def get_basis_table_from_yaml(basis_yaml: str) -> AtomicTableWithEdges: + """Reads a yaml file and initializes an AtomicTableWithEdges object. + + :param basis_yaml: Path to the yaml file. + :return: The corresponding AtomicTableWithEdges object. 
+ """ + + # Load the yaml basis file + with open(basis_yaml, "r") as f: + basis_yaml = yaml.safe_load(f) + + basis = [] + for point_basis in basis_yaml: + if isinstance(point_basis["R"], list): + point_basis["R"] = np.array(point_basis["R"]) + if isinstance(point_basis["basis"], list): + point_basis["basis"] = tuple(tuple(x) for x in point_basis["basis"]) + basis.append(PointBasis(**point_basis).to_sisl_atom(Z=point_basis["type"])) + + return AtomicTableWithEdges(basis) diff --git a/src/metatrain/experimental/graph2mat/utils/dataset.py b/src/metatrain/experimental/graph2mat/utils/dataset.py new file mode 100644 index 0000000000..c157f134e7 --- /dev/null +++ b/src/metatrain/experimental/graph2mat/utils/dataset.py @@ -0,0 +1,237 @@ +import warnings +from collections import defaultdict +from typing import Callable, Optional + +import graph2mat +import sisl +import torch +from e3nn import o3 +from graph2mat import ( + AtomicTableWithEdges, + BasisConfiguration, + BasisMatrix, + MatrixDataProcessor, +) +from graph2mat.bindings.torch import TorchBasisMatrixDataset +from graph2mat.core.data.basis import NoBasisAtom, get_change_of_basis +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatomic.torch import NeighborListOptions, System, register_autograd_neighbors + +from metatrain.utils.data import system_to_ase + +from .mtt import g2m_labels_to_tensormap + + +def system_to_config( + system: System, + data_processor: MatrixDataProcessor, + block_dict: Optional[dict[tuple[int, int, int], torch.Tensor]] = None, +) -> BasisConfiguration: + """Convert a Metatomic System to a Graph2Mat BasisConfiguration.""" + + basis = data_processor.basis_table.atoms + + geometry = sisl.Geometry.new(system_to_ase(system)) + + for atom in geometry.atoms.atom: + for basis_atom in basis: + if basis_atom.tag == atom.tag: + break + else: + basis_atom = NoBasisAtom(atom.Z, tag=atom.tag) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + 
geometry.atoms.replace_atom(atom, basis_atom) + + if block_dict is None: + matrix = None + else: + matrix = BasisMatrix(block_dict, geometry.nsc, geometry.orbitals) + + return BasisConfiguration.from_geometry(geometry, matrix=matrix) + + +def get_converters_to_spherical( + basis_table: AtomicTableWithEdges, +) -> dict[int, torch.Tensor]: + """Get the converters to spherical harmonics for each basis type in the basis table.""" + + cob, _ = get_change_of_basis(basis_table.basis_convention, "spherical") + + p_change_of_basis = torch.tensor(cob) + d_change_of_basis = o3.Irrep(2, 1).D_from_matrix(p_change_of_basis) + + converters = {} + for point_basis in basis_table.basis: + M = torch.eye(point_basis.basis_size) + + start = 0 + for mul, l, p in point_basis.basis: + if p != (-1) ** l: + raise ValueError( + "Only spherical basis with definite parity are supported." + ) + + for i in range(mul): + end = start + (2 * l + 1) + if l == 1: + M[start:end, start:end] = p_change_of_basis + elif l == 2: + M[start:end, start:end] = d_change_of_basis + else: + M[start:end, start:end] = o3.Irrep(l, 1).D_from_matrix( + p_change_of_basis + ) + start = end + + converters[point_basis.type] = M + + return converters + + +def get_graph2mat_transform( + graph2mat_processors: dict[str, MatrixDataProcessor], + nls_options: dict[str, NeighborListOptions], +) -> Callable: + """Returns a transform function that processes systems and targets + to adapt them to Graph2Mat. + + Essentially, a graph2mat batch is just a flat array, and in this + transform we convert the target in the metatrain format to this + flat array format. + + Also, each graph2mat instance will require a different graph (which + can be different from the graph used by the featurizer). Therefore, + we also compute and add the neighbor lists required by each graph2mat + instance. 
+ """ + + converters = {} + for target_name in graph2mat_processors: + converters[target_name] = get_converters_to_spherical( + graph2mat_processors[target_name].basis_table + ) + + print(converters[target_name]) + + def transform( + systems: list[System], + targets: dict[str, TensorMap], + extra: dict[str, TensorMap], + ) -> tuple[list[System], dict[str, TensorMap], dict[str, TensorMap]]: + system_indices = ( + extra["system_index"].block(0).values.ravel().to(torch.int64).tolist() + ) + + for target_name in targets: + configs = [ + system_to_config(system, graph2mat_processors[target_name], None) + for system in systems + ] + + dataset = TorchBasisMatrixDataset( + configs, + data_processor=graph2mat_processors[target_name], + data_cls=graph2mat.bindings.torch.TorchBasisMatrixData, + load_labels=False, + ) + + lattices = { + i: sisl.Lattice(data.cell.numpy(), data.nsc.reshape(3).numpy()) + for i, data in zip(system_indices, dataset, strict=True) + } + + block_dict_matrices = defaultdict(dict) + + tensormap_matrix = targets[target_name] + + for key in tensormap_matrix.keys: + block = tensormap_matrix[key] + + block_values = block.values + first_atom_type, second_atom_type = key + conv_left = converters[target_name][int(first_atom_type)].to( + block_values.dtype + ) + conv_right = converters[target_name][int(second_atom_type)].to( + block_values.dtype + ) + + block_values = torch.einsum( + "ij, bjkp, kl -> bilp", conv_left, block_values, conv_right.T + ) + + samples = block.samples.values + + for i_pair, ( + i_system, + first_atom, + second_atom, + *cell_shifts, + ) in enumerate(samples): + cell_index = lattices[int(i_system)].isc_off[ + cell_shifts[0], cell_shifts[1], cell_shifts[2] + ] + + block_dict_matrices[int(i_system)][ + (int(first_atom), int(second_atom), int(cell_index)) + ] = block_values[i_pair].squeeze(-1) + + configs = [ + system_to_config( + system, graph2mat_processors[target_name], block_dict_matrices[i] + ) + for i, system in zip(system_indices, 
systems, strict=True) + ] + + dataset = TorchBasisMatrixDataset( + configs, + data_processor=graph2mat_processors[target_name], + data_cls=graph2mat.bindings.torch.TorchBasisMatrixData, + load_labels=True, + ) + + all_point_labels = [] + all_edge_labels = [] + + for i, data in enumerate(dataset): + all_point_labels.append(data.point_labels) + all_edge_labels.append(data.edge_labels) + + edge_index = data.edge_index + distances = data.positions[edge_index].diff(dim=0) + + cell_shifts = sisl.Lattice( + data.cell.numpy(), data.nsc.reshape(3).numpy() + ).sc_off[data.neigh_isc] + + neighbor_list = TensorBlock( + values=distances.reshape(-1, 3, 1).to(systems[i].positions.dtype), + samples=Labels( + names=[ + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ], + values=torch.hstack([edge_index.T, torch.tensor(cell_shifts)]), + assume_unique=True, + ), + components=[Labels.range("xyz", 3)], + properties=Labels.range("distance", 1), + ) + + register_autograd_neighbors(systems[i], neighbor_list) + systems[i].add_neighbor_list(nls_options[target_name], neighbor_list) + + targets[target_name] = g2m_labels_to_tensormap( + node_labels=torch.cat(all_point_labels, dim=0), + edge_labels=torch.cat(all_edge_labels, dim=0), + i=i, + ) + + return systems, targets, extra + + return transform diff --git a/src/metatrain/experimental/graph2mat/utils/mtt.py b/src/metatrain/experimental/graph2mat/utils/mtt.py new file mode 100644 index 0000000000..5825e52324 --- /dev/null +++ b/src/metatrain/experimental/graph2mat/utils/mtt.py @@ -0,0 +1,195 @@ +import torch +from e3nn import o3 +from metatensor.torch import Labels, TensorBlock, TensorMap +from omegaconf import DictConfig + +from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.data.target_info import _REGISTERED_TARGET_TYPES + + +def get_mts_components(point_basis): + component_names = ["o3_lambda", "o3_sigma", "o3_mu", "i_zeta"] + + component_values = [] + for mul, l, sigma in 
point_basis.basis: + for i_zeta in range(mul): + for m in range(-l, l + 1): + component_values.append([l, sigma, m, i_zeta]) + + return Labels( + names=component_names, + values=torch.tensor(component_values, dtype=torch.int32), + ) + + +def _get_basis_target_info(target_name: str, target: DictConfig) -> TargetInfo: + keys = list([[0, 0]]) + basis_size = [5] + + layout = TensorMap( + keys=Labels(["first_atom_type", "second_atom_type"], torch.tensor(keys)), + blocks=[ + TensorBlock( + values=torch.empty( + 0, basis_size[type_0], basis_size[type_1], 1, dtype=torch.float64 + ), + samples=Labels( + names=[ + "system", + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ], + values=torch.empty((0, 6), dtype=torch.int32), + ), + components=[ + Labels( + ["first_atom_basis_function"], + torch.arange(basis_size[type_0]).reshape(-1, 1), + ), + Labels( + ["second_atom_basis_function"], + torch.arange(basis_size[type_1]).reshape(-1, 1), + ), + ], + properties=Labels(["_"], torch.tensor([[0]])), + ) + for type_0, type_1 in keys + ], + ) + + return TargetInfo( + layout=layout, + quantity=target.get("quantity", ""), + unit=target.get("unit", ""), + description=target.get("description", ""), + ) + + +_REGISTERED_TARGET_TYPES["basis"] = _get_basis_target_info + + +def _wrap_in_tensorblock(data_values, i): + return TensorBlock( + values=data_values.reshape(-1, 1).to(torch.float64), + samples=Labels( + names=["system", "matrix_element"], + values=torch.tensor( + [[i] * data_values.shape[0], torch.arange(data_values.shape[0])] + ).T, + ), + components=[], + properties=Labels(["_"], torch.tensor([[0]])), + ) + + +def g2m_labels_to_tensormap( + node_labels: torch.Tensor, edge_labels: torch.Tensor, i: int = 0 +) -> TensorMap: + return TensorMap( + keys=Labels(["graph2mat_point_or_edge"], torch.tensor([[0], [1]])), + blocks=[ + _wrap_in_tensorblock(node_labels, i), + _wrap_in_tensorblock(edge_labels, i), + ], + ) + + +def 
get_e3nn_target_info(target_name: str, target: dict) -> TargetInfo: + """Get the target info corresponding to some e3nn irreps. + + :param target_name: Name of the target. + :param target: Target dictionary containing the irreps and other info. + :return: The corresponding TargetInfo object. + """ + sample_names = ["system"] + if target["per_atom"]: + sample_names.append("atom") + + properties_name = target.get("properties_name", target_name.replace("mtt::", "")) + + irreps = o3.Irreps(target["type"]["spherical"]["irreps"]) + keys = [] + blocks = [] + for irrep in irreps: + o3_lambda = irrep.ir.l + o3_sigma = irrep.ir.p * ((-1) ** o3_lambda) + num_properties = irrep.mul + + components = [ + Labels( + names=["o3_mu"], + values=torch.arange( + -o3_lambda, o3_lambda + 1, dtype=torch.int32 + ).reshape(-1, 1), + ) + ] + block = TensorBlock( + # float64: otherwise metatensor can't serialize + values=torch.empty( + 0, + 2 * o3_lambda + 1, + num_properties, + dtype=torch.float64, + ), + samples=Labels( + names=sample_names, + values=torch.empty((0, len(sample_names)), dtype=torch.int32), + ), + components=components, + properties=Labels.range(properties_name, num_properties), + ) + keys.append([o3_lambda, o3_sigma]) + blocks.append(block) + + layout = TensorMap( + keys=Labels(["o3_lambda", "o3_sigma"], torch.tensor(keys, dtype=torch.int32)), + blocks=blocks, + ) + + target_info = TargetInfo( + quantity=target.get("quantity", ""), + unit=target.get("unit", ""), + layout=layout, + ) + return target_info + + +def split_dataset_info(dataset_info: DatasetInfo, node_hidden_irreps: str): + """Splits the dataset info into one info for the featurizer and one for graph2mat.""" + graph2mat_targets = {} + featurizer_targets = {} + for target_name, target_info in dataset_info.targets.items(): + if target_info.layout.keys.names == ["first_atom_type", "second_atom_type"]: + graph2mat_targets[target_name] = target_info + + featurizer_targets[f"mtt::aux::graph2mat_{target_name}"] = ( + 
get_e3nn_target_info( + target_name=f"mtt::aux::graph2mat_{target_name}", + target={ + "type": {"spherical": {"irreps": node_hidden_irreps}}, + "quantity": "", + "unit": "", + "per_atom": True, + "properties_name": "_", + }, + ) + ) + else: + featurizer_targets[target_name] = target_info + + featurizer_dataset_info = DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=dataset_info.atomic_types, + targets=featurizer_targets, + ) + + graph2mat_dataset_info = DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=dataset_info.atomic_types, + targets=graph2mat_targets, + ) + + return featurizer_dataset_info, graph2mat_dataset_info diff --git a/src/metatrain/experimental/graph2mat/utils/structures.py b/src/metatrain/experimental/graph2mat/utils/structures.py new file mode 100644 index 0000000000..7a13eac518 --- /dev/null +++ b/src/metatrain/experimental/graph2mat/utils/structures.py @@ -0,0 +1,122 @@ +from typing import List + +import sisl +import torch +from graph2mat import MatrixDataProcessor +from metatomic.torch import NeighborListOptions, System + + +def create_batch( + systems: List[System], + neighbor_list_options: NeighborListOptions, + atomic_types_to_species_index: torch.Tensor, + n_types: int, + data_processor: MatrixDataProcessor, +) -> dict[str, torch.Tensor]: + """Creates a torch geometric-like batch from a list of systems. + + The batch returned by this function can be used as input + for MACE models. + + :param systems: List of systems to batch. + :param neighbor_list_options: Options to create the neighbor lists. + :param atomic_types_to_species_index: Mapping from atomic types to species index. + :param n_types: Number of different species. + + :return: A dictionary containing the batched data. 
+ """ + unit_shifts = [] + cell_shifts = [] + edge_index = [] + atom_types = [] + edge_types = [] + neigh_isc = [] + batch = [] + system_start_index = [0] + + dtype = systems[0].positions.dtype + device = systems[0].device + + for system_i, system in enumerate(systems): + neighbors = system.get_neighbor_list(neighbor_list_options) + start_index = system_start_index[-1] + + # TODO: make this faster? + system_atom_types = atomic_types_to_species_index[system.types] + atom_types.append(system_atom_types) + + shifts = neighbors.samples.view( + ["cell_shift_a", "cell_shift_b", "cell_shift_c"] + ).values.T + + system_edge_index = neighbors.samples.view( + ["first_atom", "second_atom"] + ).values.T.to(torch.int64) + system_cell_shifts = shifts.T.to(dtype) @ system.cell + + # Get the edge types + system_edge_types = data_processor.basis_table.point_type_to_edge_type( + system_atom_types[system_edge_index] + ) + + # Check if there are any edges + any_edges = system_edge_index.shape[1] > 0 + + # Get the number of supercells needed along each direction to account for all interactions + if any_edges: + nsc = abs(shifts).max(axis=1).values * 2 + 1 + else: + nsc = torch.tensor([1, 1, 1]) + + # Then build the supercell that encompasses all of those atoms, so that we can get the + # array that converts from sc shifts (3D) to a single supercell index. This is isc_off. + supercell = sisl.Lattice(system.cell, nsc=nsc) + + edge_index.append(system_edge_index + start_index) + edge_types.append(torch.from_numpy(system_edge_types).to(torch.int64)) + cell_shifts.append(system_cell_shifts) + unit_shifts.append(shifts.T) + + # Then, get the supercell index of each interaction. 
+ neigh_isc.append( + torch.tensor(supercell.isc_off[shifts[0], shifts[1], shifts[2]]) + ) + + n_atoms = len(system) + batch.append(torch.full((n_atoms,), system_i)) + system_start_index.append(start_index + n_atoms) + + return { + "positions": torch.vstack([s.positions for s in systems]), + "cell": torch.vstack([s.cell for s in systems]), + "unit_shifts": torch.vstack(unit_shifts), + "edge_index": torch.hstack(edge_index), + "shifts": torch.vstack(cell_shifts), + "head": torch.tensor([0] * len(systems)).to(device), + "batch": torch.hstack(batch).to(device), + "ptr": torch.tensor(system_start_index).to(device), + "node_attrs": torch.nn.functional.one_hot( + torch.hstack(atom_types), num_classes=n_types + ).to(dtype), + "point_types": torch.hstack(atom_types), + "edge_types": torch.hstack(edge_types), + "neigh_isc": torch.hstack(neigh_isc), + } + + +def get_edge_vectors_and_lengths( + positions: torch.Tensor, # [n_nodes, 3] + edge_index: torch.Tensor, # [2, n_edges] + shifts: torch.Tensor, # [n_edges, 3] + normalize: bool = False, + eps: float = 1e-9, +) -> tuple[torch.Tensor, torch.Tensor]: + sender = edge_index[0] + receiver = edge_index[1] + vectors = positions[receiver] - positions[sender] + shifts # [n_edges, 3] + lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True) # [n_edges, 1] + if normalize: + vectors_normed = vectors / (lengths + eps) + return vectors_normed, lengths + + return vectors, lengths diff --git a/src/metatrain/share/base_hypers.py b/src/metatrain/share/base_hypers.py index 6a12d1ebad..6a2a2cb802 100644 --- a/src/metatrain/share/base_hypers.py +++ b/src/metatrain/share/base_hypers.py @@ -75,6 +75,7 @@ class GradientDict(TypedDict): ScalarTargetTypeHyper = Literal["scalar"] +BasisTargetTypeHyper = Literal["basis"] @with_config(ConfigDict(extra="forbid", strict=True)) @@ -143,7 +144,10 @@ class TargetHypers(TypedDict): """Whether the target is a per-atom quantity, as opposed to a global (per-structure) quantity.""" type: NotRequired[ - 
ScalarTargetTypeHyper | CartesianTargetTypeHypers | SphericalTargetTypeHypers + ScalarTargetTypeHyper + | CartesianTargetTypeHypers + | SphericalTargetTypeHypers + | BasisTargetTypeHyper ] """Specifies the type of the target. diff --git a/src/metatrain/utils/data/dataset.py b/src/metatrain/utils/data/dataset.py index 1f8325bf7a..e7a2f0cefd 100644 --- a/src/metatrain/utils/data/dataset.py +++ b/src/metatrain/utils/data/dataset.py @@ -705,7 +705,9 @@ def get_target_info(self, target_config: DictConfig) -> Dict[str, TargetInfo]: target_info_dict[target_key] = target_info else: target_info = get_generic_target_info(target_key, target) - _check_tensor_map_metadata(tensor_map, target_info.layout) + + if not target_info.is_basis: + _check_tensor_map_metadata(tensor_map, target_info.layout) # make sure that the properties of the target_info.layout also match the # actual properties of the tensor maps target_info.layout = _empty_tensor_map_like(tensor_map) diff --git a/src/metatrain/utils/data/target_info.py b/src/metatrain/utils/data/target_info.py index 367c9b605b..1213890b66 100644 --- a/src/metatrain/utils/data/target_info.py +++ b/src/metatrain/utils/data/target_info.py @@ -44,6 +44,7 @@ def __init__( self.is_scalar = False self.is_cartesian = False self.is_spherical = False + self.is_basis = False self._check_layout(layout) self.layout = layout @@ -107,15 +108,31 @@ def _check_layout(self, layout: TensorMap) -> None: :param layout: The layout TensorMap to check. """ + valid_sample_names = [ + ["system"], + [ + "system", + "atom", + ], + [ + "system", + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ], + ] + + if layout.sample_names not in valid_sample_names: + raise ValueError( + "The layout ``TensorMap`` of a target should only have samples " + f"named as one of: {valid_sample_names}, but found " + f"'{layout.sample_names}' instead." 
+ ) + # examine basic properties of all blocks for block in layout.blocks(): - for sample_name in block.samples.names: - if sample_name not in ["system", "atom"]: - raise ValueError( - "The layout ``TensorMap`` of a target should only have samples " - "named 'system' or 'atom', but found " - f"'{sample_name}' instead." - ) if len(block.values) != 0: raise ValueError( "The layout ``TensorMap`` of a target should have 0 " @@ -140,6 +157,10 @@ def _check_layout(self, layout: TensorMap) -> None: and components_first_block[0].names[0] == "o3_mu" ): self.is_spherical = True + elif len(components_first_block) == 2 and components_first_block[0].names[ + 0 + ].endswith("basis_function"): + self.is_basis = True else: raise ValueError( "The layout ``TensorMap`` of a target should be " @@ -393,16 +414,22 @@ def get_generic_target_info(target_name: str, target: DictConfig) -> TargetInfo: :return: A `TargetInfo` with the layout of the target. """ - if target["type"] == "scalar": - return _get_scalar_target_info(target_name, target) - elif len(target["type"]) == 1 and next(iter(target["type"])).lower() == "cartesian": - return _get_cartesian_target_info(target_name, target) - elif len(target["type"]) == 1 and next(iter(target["type"])) == "spherical": - return _get_spherical_target_info(target_name, target) + if isinstance(target["type"], str): + target_type = target["type"].lower() + elif len(target["type"]) == 1 and isinstance(next(iter(target["type"])), str): + target_type = next(iter(target["type"])).lower() + else: + raise ValueError( + "Couldn't infer target type from the 'type' field of the target configuration." + f" Found: {target['type']}" + ) + + if target_type in _REGISTERED_TARGET_TYPES: + return _REGISTERED_TARGET_TYPES[target_type](target_name, target) else: raise ValueError( f"Target type {target['type']} is not supported. " - "Supported types are 'scalar', 'cartesian' and 'spherical'." + f"Supported types are {list(_REGISTERED_TARGET_TYPES.keys())}." 
) @@ -533,6 +560,13 @@ def _get_spherical_target_info(target_name: str, target: DictConfig) -> TargetIn ) +_REGISTERED_TARGET_TYPES = { + "scalar": _get_scalar_target_info, + "cartesian": _get_cartesian_target_info, + "spherical": _get_spherical_target_info, +} + + def is_auxiliary_output(name: str) -> bool: """ Check if a target name corresponds to an auxiliary output. From 89f398db98445bdd26d8635a10e3b9633137daa3 Mon Sep 17 00:00:00 2001 From: Pol Febrer Calabozo Date: Tue, 16 Dec 2025 17:37:25 +0100 Subject: [PATCH 2/5] Make it work on GPU --- .../experimental/graph2mat/utils/dataset.py | 2 -- .../experimental/graph2mat/utils/mtt.py | 10 ++++-- .../graph2mat/utils/structures.py | 31 ++++++++++--------- 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/src/metatrain/experimental/graph2mat/utils/dataset.py b/src/metatrain/experimental/graph2mat/utils/dataset.py index c157f134e7..edd2c6df7a 100644 --- a/src/metatrain/experimental/graph2mat/utils/dataset.py +++ b/src/metatrain/experimental/graph2mat/utils/dataset.py @@ -113,8 +113,6 @@ def get_graph2mat_transform( graph2mat_processors[target_name].basis_table ) - print(converters[target_name]) - def transform( systems: list[System], targets: dict[str, TensorMap], diff --git a/src/metatrain/experimental/graph2mat/utils/mtt.py b/src/metatrain/experimental/graph2mat/utils/mtt.py index 5825e52324..11cdd71146 100644 --- a/src/metatrain/experimental/graph2mat/utils/mtt.py +++ b/src/metatrain/experimental/graph2mat/utils/mtt.py @@ -77,11 +77,12 @@ def _wrap_in_tensorblock(data_values, i): samples=Labels( names=["system", "matrix_element"], values=torch.tensor( - [[i] * data_values.shape[0], torch.arange(data_values.shape[0])] + [[i] * data_values.shape[0], torch.arange(data_values.shape[0])], + device=data_values.device, ).T, ), components=[], - properties=Labels(["_"], torch.tensor([[0]])), + properties=Labels(["_"], torch.tensor([[0]], device=data_values.device)), ) @@ -89,7 +90,10 @@ def 
g2m_labels_to_tensormap( node_labels: torch.Tensor, edge_labels: torch.Tensor, i: int = 0 ) -> TensorMap: return TensorMap( - keys=Labels(["graph2mat_point_or_edge"], torch.tensor([[0], [1]])), + keys=Labels( + ["graph2mat_point_or_edge"], + torch.tensor([[0], [1]], device=node_labels.device), + ), blocks=[ _wrap_in_tensorblock(node_labels, i), _wrap_in_tensorblock(edge_labels, i), diff --git a/src/metatrain/experimental/graph2mat/utils/structures.py b/src/metatrain/experimental/graph2mat/utils/structures.py index 7a13eac518..fbaa50e2ad 100644 --- a/src/metatrain/experimental/graph2mat/utils/structures.py +++ b/src/metatrain/experimental/graph2mat/utils/structures.py @@ -55,13 +55,24 @@ def create_batch( system_cell_shifts = shifts.T.to(dtype) @ system.cell # Get the edge types - system_edge_types = data_processor.basis_table.point_type_to_edge_type( - system_atom_types[system_edge_index] + edge_type = torch.tensor(data_processor.basis_table.edge_type).to( + device, torch.int64 ) + edge_atoms = system_atom_types[system_edge_index] + system_edge_types = edge_type[edge_atoms[0], edge_atoms[1]] # Check if there are any edges any_edges = system_edge_index.shape[1] > 0 + edge_index.append(system_edge_index + start_index) + edge_types.append(system_edge_types) + cell_shifts.append(system_cell_shifts) + unit_shifts.append(shifts.T) + + n_atoms = len(system) + batch.append(torch.full((n_atoms,), system_i)) + system_start_index.append(start_index + n_atoms) + # Get the number of supercells needed along each direction to account for all interactions if any_edges: nsc = abs(shifts).max(axis=1).values * 2 + 1 @@ -70,21 +81,11 @@ def create_batch( # Then build the supercell that encompasses all of those atoms, so that we can get the # array that converts from sc shifts (3D) to a single supercell index. This is isc_off. 
- supercell = sisl.Lattice(system.cell, nsc=nsc) - - edge_index.append(system_edge_index + start_index) - edge_types.append(torch.from_numpy(system_edge_types).to(torch.int64)) - cell_shifts.append(system_cell_shifts) - unit_shifts.append(shifts.T) + supercell = sisl.Lattice(system.cell.cpu(), nsc=nsc.cpu()) # Then, get the supercell index of each interaction. - neigh_isc.append( - torch.tensor(supercell.isc_off[shifts[0], shifts[1], shifts[2]]) - ) - - n_atoms = len(system) - batch.append(torch.full((n_atoms,), system_i)) - system_start_index.append(start_index + n_atoms) + isc_off = torch.from_numpy(supercell.isc_off).to(device) + neigh_isc.append(isc_off[shifts[0], shifts[1], shifts[2]]) return { "positions": torch.vstack([s.positions for s in systems]), From 46542c95b660491ef4b66a8c82c3080a88373724 Mon Sep 17 00:00:00 2001 From: Pol Febrer Calabozo Date: Tue, 16 Dec 2025 17:47:27 +0100 Subject: [PATCH 3/5] Add extra dependency --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f9bb128ff8..e5ce459884 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,7 +81,7 @@ mace = [ "e3nn" ] graph2mat = [ - "graph2mat", + "graph2mat[e3nn]", "e3nn", "sisl", ] From b19d5085e50ba46e3d81b89fe597e2cc9e0291fb Mon Sep 17 00:00:00 2001 From: Pol Febrer Date: Mon, 5 Jan 2026 12:44:57 +0100 Subject: [PATCH 4/5] Make distributed work --- src/metatrain/experimental/graph2mat/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/metatrain/experimental/graph2mat/trainer.py b/src/metatrain/experimental/graph2mat/trainer.py index 3a35000335..a61ce5a210 100644 --- a/src/metatrain/experimental/graph2mat/trainer.py +++ b/src/metatrain/experimental/graph2mat/trainer.py @@ -299,7 +299,7 @@ def train( ) per_structure_targets = [ - *model.graph2mat_dataset_info.targets, + *(model.module if is_distributed else model).graph2mat_dataset_info.targets, *self.hypers["per_structure_targets"], ] From 
8ba4ed65d2a456fa71bfdfbfcda0eaf3cd104e18 Mon Sep 17 00:00:00 2001 From: Pol Febrer Calabozo Date: Tue, 6 Jan 2026 17:54:45 +0100 Subject: [PATCH 5/5] Reorganized data handling and added temporary eval script --- .../experimental/graph2mat/__init__.py | 1 + src/metatrain/experimental/graph2mat/model.py | 8 +- .../graph2mat/utils/conversions.py | 177 ++++++++++++++++++ .../experimental/graph2mat/utils/dataset.py | 131 +++++++------ .../experimental/graph2mat/utils/eval.py | 159 ++++++++++++++++ 5 files changed, 412 insertions(+), 64 deletions(-) create mode 100644 src/metatrain/experimental/graph2mat/utils/conversions.py create mode 100644 src/metatrain/experimental/graph2mat/utils/eval.py diff --git a/src/metatrain/experimental/graph2mat/__init__.py b/src/metatrain/experimental/graph2mat/__init__.py index 18522d46ea..735a83229c 100644 --- a/src/metatrain/experimental/graph2mat/__init__.py +++ b/src/metatrain/experimental/graph2mat/__init__.py @@ -1,5 +1,6 @@ from .model import MetaGraph2Mat from .trainer import Trainer +from .utils.conversions import * from .utils.mtt import _get_basis_target_info diff --git a/src/metatrain/experimental/graph2mat/model.py b/src/metatrain/experimental/graph2mat/model.py index c7f7b7475c..55c81a5bcd 100644 --- a/src/metatrain/experimental/graph2mat/model.py +++ b/src/metatrain/experimental/graph2mat/model.py @@ -386,17 +386,11 @@ def load_checkpoint( model = cls(**model_data) # Infer dtype dtype = None - # Otherwise, just look at the weights in the state dict - for k, v in model_state_dict.items(): - if k.endswith(".weight"): - dtype = v.dtype - break - else: - raise ValueError("Couldn't infer dtype from the checkpoint file") # Set up composition and scaler models # model.additive_models[0].sync_tensor_maps() # model.scaler.sync_tensor_maps() + model.load_state_dict(model_state_dict) # Loading the metadata from the checkpoint model.metadata = merge_metadata(model.metadata, checkpoint.get("metadata")) diff --git 
a/src/metatrain/experimental/graph2mat/utils/conversions.py b/src/metatrain/experimental/graph2mat/utils/conversions.py new file mode 100644 index 0000000000..a6ca3903ef --- /dev/null +++ b/src/metatrain/experimental/graph2mat/utils/conversions.py @@ -0,0 +1,177 @@ +from collections import defaultdict + +import numpy as np +import sisl +import torch +from e3nn import o3 +from graph2mat import ( + Formats, + conversions, +) +from graph2mat.core.data.basis import get_change_of_basis +from metatensor.torch import Labels, TensorBlock, TensorMap + + +Formats.TENSORMAP = "tensormap" + + +@conversions.converter(Formats.BASISCONFIGURATION, Formats.TENSORMAP) +def basisconfiguration_to_tensormap(config, i: int = 0): + tensorblocks = defaultdict(list) + tensorblocks_samples = defaultdict(list) + block_shapes = {} + + block_dict_matrix = config.matrix.block_dict + lattice = sisl.Lattice(config.cell, nsc=config.matrix.nsc) + + for k, v in block_dict_matrix.items(): + first_atom, second_atom, cell_index = k + + cell_shift = lattice.sc_off[cell_index] + + center_type = config.point_types[first_atom] + neighbor_type = config.point_types[second_atom] + + tensorblocks[(center_type, neighbor_type)].append(v) + tensorblocks_samples[(center_type, neighbor_type)].append( + [i, first_atom, second_atom, *cell_shift] + ) + + block_shapes[(center_type, neighbor_type)] = v.shape + + tensorblocks = {k: np.array(v) for k, v in tensorblocks.items()} + + keys = list(tensorblocks.keys()) + + return TensorMap( + keys=Labels(["first_atom_type", "second_atom_type"], torch.tensor(keys)), + blocks=[ + TensorBlock( + values=torch.tensor(tensorblocks[key]) + .reshape(-1, block_shapes[key][0], block_shapes[key][1], 1) + .to(torch.float64), + samples=Labels( + names=[ + "system", + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ], + values=torch.tensor(tensorblocks_samples[key]), + ), + components=[ + Labels( + ["first_atom_basis_function"], + 
torch.arange(block_shapes[key][0]).reshape(-1, 1), + ), + Labels( + ["second_atom_basis_function"], + torch.arange(block_shapes[key][1]).reshape(-1, 1), + ), + ], + properties=Labels(["_"], torch.tensor([[0]])), + ) + for key in keys + ], + ) + + +def get_target_converters( + basis_table, + in_format: str, + out_format: str, +) -> dict[int, torch.Tensor]: + """Get the converters from spherical harmonics for each basis type in the basis table.""" + + cob, _ = get_change_of_basis(in_format, out_format) + + p_change_of_basis = torch.tensor(cob) + d_change_of_basis = o3.Irrep(2, 1).D_from_matrix(p_change_of_basis) + + converters = {} + for point_basis in basis_table.basis: + M = torch.eye(point_basis.basis_size) + + start = 0 + for mul, l, p in point_basis.basis: + if p != (-1) ** l: + raise ValueError( + "Only spherical basis with definite parity are supported." + ) + + for i in range(mul): + end = start + (2 * l + 1) + if l == 1: + M[start:end, start:end] = p_change_of_basis + elif l == 2: + M[start:end, start:end] = d_change_of_basis + else: + M[start:end, start:end] = o3.Irrep(l, 1).D_from_matrix( + p_change_of_basis + ) + start = end + + converters[point_basis.type] = M + + return converters + + +def transform_tensormap_matrix( + tmap: TensorMap, + converters: dict[int, torch.Tensor], +) -> TensorMap: + transformed_blocks = [] + + for key in tmap.keys: + block = tmap[key] + + block_values = block.values + first_atom_type, second_atom_type = key + + conv_left = converters[int(first_atom_type)].to(block_values.dtype) + conv_right = converters[int(second_atom_type)].to(block_values.dtype) + + transformed_values = torch.einsum( + "ij, bjkp, kl -> bilp", conv_left, block_values, conv_right.T + ) + + transformed_block = TensorBlock( + values=transformed_values, + samples=block.samples, + components=block.components, + properties=block.properties, + ) + + transformed_blocks.append(transformed_block) + + return TensorMap( + keys=tmap.keys, + blocks=transformed_blocks, + ) 
+ + +@conversions.converter(Formats.TENSORMAP, Formats.BLOCK_DICT) +def tensormap_to_blockdict(tmap, lattice): + block_dict = {} + + for key in tmap.keys: + block = tmap[key] + + block_values = block.values + samples = block.samples.values + + for i_pair, ( + i_system, + first_atom, + second_atom, + *cell_shifts, + ) in enumerate(samples): + cell_index = lattice.isc_off[cell_shifts[0], cell_shifts[1], cell_shifts[2]] + + block_dict[(int(first_atom), int(second_atom), int(cell_index))] = ( + block_values[i_pair].squeeze(-1) + ) + + return block_dict diff --git a/src/metatrain/experimental/graph2mat/utils/dataset.py b/src/metatrain/experimental/graph2mat/utils/dataset.py index edd2c6df7a..a4e5d68a1f 100644 --- a/src/metatrain/experimental/graph2mat/utils/dataset.py +++ b/src/metatrain/experimental/graph2mat/utils/dataset.py @@ -5,20 +5,19 @@ import graph2mat import sisl import torch -from e3nn import o3 from graph2mat import ( - AtomicTableWithEdges, BasisConfiguration, BasisMatrix, MatrixDataProcessor, ) from graph2mat.bindings.torch import TorchBasisMatrixDataset -from graph2mat.core.data.basis import NoBasisAtom, get_change_of_basis +from graph2mat.core.data.basis import NoBasisAtom from metatensor.torch import Labels, TensorBlock, TensorMap from metatomic.torch import NeighborListOptions, System, register_autograd_neighbors from metatrain.utils.data import system_to_ase +from .conversions import get_target_converters, transform_tensormap_matrix from .mtt import g2m_labels_to_tensormap @@ -52,44 +51,6 @@ def system_to_config( return BasisConfiguration.from_geometry(geometry, matrix=matrix) -def get_converters_to_spherical( - basis_table: AtomicTableWithEdges, -) -> dict[int, torch.Tensor]: - """Get the converters to spherical harmonics for each basis type in the basis table.""" - - cob, _ = get_change_of_basis(basis_table.basis_convention, "spherical") - - p_change_of_basis = torch.tensor(cob) - d_change_of_basis = o3.Irrep(2, 1).D_from_matrix(p_change_of_basis) - 
- converters = {} - for point_basis in basis_table.basis: - M = torch.eye(point_basis.basis_size) - - start = 0 - for mul, l, p in point_basis.basis: - if p != (-1) ** l: - raise ValueError( - "Only spherical basis with definite parity are supported." - ) - - for i in range(mul): - end = start + (2 * l + 1) - if l == 1: - M[start:end, start:end] = p_change_of_basis - elif l == 2: - M[start:end, start:end] = d_change_of_basis - else: - M[start:end, start:end] = o3.Irrep(l, 1).D_from_matrix( - p_change_of_basis - ) - start = end - - converters[point_basis.type] = M - - return converters - - def get_graph2mat_transform( graph2mat_processors: dict[str, MatrixDataProcessor], nls_options: dict[str, NeighborListOptions], @@ -109,8 +70,10 @@ def get_graph2mat_transform( converters = {} for target_name in graph2mat_processors: - converters[target_name] = get_converters_to_spherical( - graph2mat_processors[target_name].basis_table + converters[target_name] = get_target_converters( + graph2mat_processors[target_name].basis_table, + in_format=graph2mat_processors[target_name].basis_table.basis_convention, + out_format="spherical", ) def transform( @@ -142,24 +105,13 @@ def transform( block_dict_matrices = defaultdict(dict) - tensormap_matrix = targets[target_name] + tensormap_matrix = transform_tensormap_matrix( + targets[target_name], converters=converters[target_name] + ) for key in tensormap_matrix.keys: block = tensormap_matrix[key] - block_values = block.values - first_atom_type, second_atom_type = key - conv_left = converters[target_name][int(first_atom_type)].to( - block_values.dtype - ) - conv_right = converters[target_name][int(second_atom_type)].to( - block_values.dtype - ) - - block_values = torch.einsum( - "ij, bjkp, kl -> bilp", conv_left, block_values, conv_right.T - ) - samples = block.samples.values for i_pair, ( @@ -233,3 +185,68 @@ def transform( return systems, targets, extra return transform + + +def get_graph2mat_eval_transform( + graph2mat_processors: 
dict[str, MatrixDataProcessor], + nls_options: dict[str, NeighborListOptions], + outputs: Optional[list[str]] = None, +) -> Callable: + """Same as `get_graph2mat_transform`, but for evaluation.""" + converters = {} + for target_name in graph2mat_processors: + converters[target_name] = get_target_converters( + graph2mat_processors[target_name].basis_table, + in_format=graph2mat_processors[target_name].basis_table.basis_convention, + out_format="spherical", + ) + + def transform( + systems: list[System], + targets: dict[str, TensorMap], + extra: dict[str, TensorMap], + ) -> tuple[list[System], dict[str, TensorMap], dict[str, TensorMap]]: + for target_name in outputs: + configs = [ + system_to_config(system, graph2mat_processors[target_name], None) + for system in systems + ] + + dataset = TorchBasisMatrixDataset( + configs, + data_processor=graph2mat_processors[target_name], + data_cls=graph2mat.bindings.torch.TorchBasisMatrixData, + load_labels=False, + ) + + for i, data in enumerate(dataset): + edge_index = data.edge_index + distances = data.positions[edge_index].diff(dim=0) + + cell_shifts = sisl.Lattice( + data.cell.numpy(), data.nsc.reshape(3).numpy() + ).sc_off[data.neigh_isc] + + neighbor_list = TensorBlock( + values=distances.reshape(-1, 3, 1).to(systems[i].positions.dtype), + samples=Labels( + names=[ + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ], + values=torch.hstack([edge_index.T, torch.tensor(cell_shifts)]), + assume_unique=True, + ), + components=[Labels.range("xyz", 3)], + properties=Labels.range("distance", 1), + ) + + register_autograd_neighbors(systems[i], neighbor_list) + systems[i].add_neighbor_list(nls_options[target_name], neighbor_list) + + return systems, targets, extra + + return transform diff --git a/src/metatrain/experimental/graph2mat/utils/eval.py b/src/metatrain/experimental/graph2mat/utils/eval.py new file mode 100644 index 0000000000..1b3e515f86 --- /dev/null +++ 
b/src/metatrain/experimental/graph2mat/utils/eval.py @@ -0,0 +1,159 @@ +"""Little helper script to evaluate the model from a ckpt file while +it is not torchscript compatible yet.""" + +import argparse + +import graph2mat +import sisl +import torch +from graph2mat.bindings.torch import TorchBasisMatrixData, TorchBasisMatrixDataset +from metatomic.torch import ModelOutput + +from metatrain.experimental.graph2mat import MetaGraph2Mat +from metatrain.experimental.graph2mat.utils.conversions import ( + get_target_converters, + transform_tensormap_matrix, +) +from metatrain.experimental.graph2mat.utils.dataset import ( + get_graph2mat_eval_transform, + system_to_config, +) +from metatrain.utils.data import CollateFn, Dataset, read_systems, unpack_batch +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists_transform, +) +from metatrain.utils.transfer import batch_to + + +# ---------------------------------------- +# Argument parsing +# ---------------------------------------- + +parser = argparse.ArgumentParser(description="Evaluate Graph2Mat model from checkpoint") +parser.add_argument( + "input_file", type=str, help="Input file containing the systems (e.g., XYZ format)" +) +parser.add_argument("model_ckpt", type=str, help="Path to the model checkpoint file") +parser.add_argument( + "--targets", + nargs="+", + default=["density_matrix"], + help="List of target properties to evaluate", +) +args = parser.parse_args() + +# ---------------------------------------- +# Reading input data and preparing model +# ---------------------------------------- +systems = read_systems( + filename=args.input_file, + reader="ase", +) +targets = {target: ModelOutput() for target in args.targets} +dataset = Dataset.from_dict({"system": systems}) + +ckpt = torch.load(args.model_ckpt, map_location="cpu") +model = MetaGraph2Mat.load_checkpoint(ckpt, context="export") + +requested_neighbor_lists = 
get_requested_neighbor_lists(model.featurizer_model) +collate_fn = CollateFn( + list(targets), + callables=[ + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_graph2mat_eval_transform( + model.graph2mat_processors, model.graph2mat_nls, outputs=list(targets) + ), + ], +) +dataloader = torch.utils.data.DataLoader( + dataset, batch_size=1, collate_fn=collate_fn, shuffle=False +) + +# --------------------------------------------------------------- +# Helpers to convert from spherical harmonics to the basis used +# in the target data. +# --------------------------------------------------------------- + +converters = {} +for target_name in targets: + converters[target_name] = get_target_converters( + model.graph2mat_processors[target_name].basis_table, + in_format="spherical", + out_format=model.graph2mat_processors[target_name].basis_table.basis_convention, + ) + + +def spherical_to_basis( + data: TorchBasisMatrixData, + converters: dict, + data_processor: graph2mat.MatrixDataProcessor, +): + """The metatrain graph2mat model predicts the matrices in spherical harmonics basis. + However, the target might be in a slightly different convention + (e.g. Y-ZX instead of YZX). + + This function (inefficiently) converts the predicted data into the right basis convention. 
+ """ + dm = graph2mat.conversions.torch_basismatrixdata_to_sisl_DM(data) + config = graph2mat.conversions.sisl_to_orbitalconfiguration(dm) + tmap = graph2mat.conversions.basisconfiguration_to_tensormap(config) + tmap = transform_tensormap_matrix(tmap, converters=converters) + converted_bdict = graph2mat.conversions.tensormap_to_block_dict( + tmap, lattice=sisl.Lattice(config.cell, nsc=config.matrix.nsc) + ) + config.matrix.block_dict = converted_bdict + data = graph2mat.conversions.orbitalconfiguration_to_basismatrixdata( + config, data_processor + ) + return data + + +# ---------------------------------------- +# Evaluation loop +# ---------------------------------------- + +for batch in dataloader: + systems, batch_targets, batch_extra_data = unpack_batch(batch) + systems, batch_targets, batch_extra_data = batch_to( + systems, batch_targets, batch_extra_data, dtype=torch.float32, device="cpu" + ) + + out = model(systems, outputs=targets) + + for target in args.targets: + dm_tensormap = out[target] + + configs = [ + system_to_config(system, model.graph2mat_processors[target], None) + for system in systems + ] + + dataset = TorchBasisMatrixDataset( + configs, + data_processor=model.graph2mat_processors[target], + data_cls=graph2mat.bindings.torch.TorchBasisMatrixData, + load_labels=False, + ) + data = dataset[0] + + data["point_labels"] = dm_tensormap.block(0).values.ravel() + data["edge_labels"] = dm_tensormap.block(1).values.ravel() + + data = spherical_to_basis( + data, + converters=converters[target], + data_processor=model.graph2mat_processors[target], + ) + + if target == "density_matrix": + dm = graph2mat.conversions.torch_basismatrixdata_to_sisl_DM(data) + dm.write("prediction.DM") + elif target == "hamiltonian": + hamiltonian = graph2mat.conversions.torch_basismatrixdata_to_sisl_H(data) + hamiltonian.write("prediction.TSHS") + elif target == "overlap_matrix": + overlap_matrix = graph2mat.conversions.torch_basismatrixdata_to_sisl_S(data) + 
overlap_matrix.write("prediction.TSHS") + else: + print(f"Writing for target {target} not implemented.")