diff --git a/.github/workflows/architecture-tests.yml b/.github/workflows/architecture-tests.yml index b9fcc68686..ed55f27c84 100644 --- a/.github/workflows/architecture-tests.yml +++ b/.github/workflows/architecture-tests.yml @@ -19,6 +19,7 @@ jobs: - llpr - mace - nanopet + - dpa3 - pet - soap-bpnn diff --git a/CODEOWNERS b/CODEOWNERS index 239b7ce203..7aec4ce689 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -6,6 +6,7 @@ **/pet @abmazitov **/gap @DavideTisi **/nanopet @frostedoyster +**/dpa3 @wentaoli **/llpr @frostedoyster @SanggyuChong **/flashmd @johannes-spies @frostedoyster **/classifier @frostedoyster diff --git a/README.md b/README.md index 296a470454..2225cfe64f 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ model (sorted by alphabetic order): | Name | Description | |--------------------------|--------------------------------------------------------------------------------------------------------------------------------------| +| DPA3 | An invariant graph neural network based on line graph series representations | | FlashMD | An architecture for the direct prediction of molecular dynamics | | GAP | Sparse Gaussian Approximation Potential (GAP) using Smooth Overlap of Atomic Positions (SOAP). | | MACE | A higher order equivariant message passing neural network. | diff --git a/examples/1-advanced/03-fitting-generic-targets.py b/examples/1-advanced/03-fitting-generic-targets.py index ddda39e41d..9616a38b55 100644 --- a/examples/1-advanced/03-fitting-generic-targets.py +++ b/examples/1-advanced/03-fitting-generic-targets.py @@ -46,6 +46,11 @@ - Yes - Yes - Only with ``rank=1`` (vectors) + * - DPA3 + - Energy, forces, virial + - Yes + - No + - No Preparing generic targets for reading by metatrain -------------------------------------------------- diff --git a/pyproject.toml b/pyproject.toml index c36d5a3b1e..12fb5ac57c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,9 @@ requires = [ build-backend = "setuptools.build_meta" [project.optional-dependencies] +dpa3 = [ + "deepmd-kit[torch]>=3.1.0", +] soap-bpnn = [ "torch-spex>=0.1,<0.2", "wigners", diff --git a/src/metatrain/experimental/dpa3/__init__.py b/src/metatrain/experimental/dpa3/__init__.py new file mode 100644 index 0000000000..66be2be6cb --- /dev/null +++ b/src/metatrain/experimental/dpa3/__init__.py @@ -0,0 +1,15 @@ +from .model import DPA3 +from .trainer import Trainer + + +__model__ = DPA3 +__trainer__ = Trainer + +__authors__ = [ + ("Duo Zhang ", "@duozhang"), +] + +__maintainers__ = [ + ("Duo Zhang ", "@duozhang"), + ("Wentao Li ", "@wentaoli"), +] diff --git a/src/metatrain/experimental/dpa3/checkpoints.py b/src/metatrain/experimental/dpa3/checkpoints.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/metatrain/experimental/dpa3/documentation.py b/src/metatrain/experimental/dpa3/documentation.py new file mode 100644 index 0000000000..2d897b65c9 --- /dev/null +++ b/src/metatrain/experimental/dpa3/documentation.py @@ -0,0 +1,149 @@ +""" +DPA3 (experimental) +====================== + +This is an interface to the DPA3 architecture described in https://arxiv.org/abs/2506.01686 +and implemented in deepmd-kit (https://github.com/deepmodeling/deepmd-kit). 
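+
+As a minimal sketch (not an exhaustive reference), an options file for this
+architecture could look like the following. It assumes the same top-level
+``architecture``/``model``/``training`` layout used by the other metatrain
+architectures; the keys mirror the default hyperparameters defined in this
+module, and any omitted key falls back to its default value:
+
+.. code-block:: yaml
+
+    architecture:
+      name: experimental.dpa3
+      model:
+        descriptor:
+          repflow:
+            e_rcut: 6.0  # edge cutoff radius
+            a_rcut: 4.0  # angle cutoff radius
+      training:
+        batch_size: 8
+        num_epochs: 100
+        learning_rate: 0.001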
+""" + +from typing import Literal, Optional + +from typing_extensions import TypedDict + +from metatrain.utils.additive import FixedCompositionWeights +from metatrain.utils.hypers import init_with_defaults +from metatrain.utils.loss import LossSpecification + + +########################### +# MODEL HYPERPARAMETERS # +########################### + + +class RepflowHypers(TypedDict): + n_dim: int = 128 + e_dim: int = 64 + a_dim: int = 32 + nlayers: int = 6 + e_rcut: float = 6.0 + e_rcut_smth: float = 5.3 + e_sel: int = 1200 + a_rcut: float = 4.0 + a_rcut_smth: float = 3.5 + a_sel: int = 300 + axis_neuron: int = 4 + skip_stat: bool = True + a_compress_rate: int = 1 + a_compress_e_rate: int = 2 + a_compress_use_split: bool = True + update_angle: bool = True + # TODO: what are the options here + update_style: str = "res_residual" + update_residual: float = 0.1 + # TODO: what are the options here + update_residual_init: str = "const" + smooth_edge_update: bool = True + use_dynamic_sel: bool = True + sel_reduce_factor: float = 10.0 + + +class DescriptorHypers(TypedDict): + # TODO: what are the options here + type: str = "dpa3" + repflow: RepflowHypers = init_with_defaults(RepflowHypers) + # TODO: what are the options here + activation_function: str = "custom_silu:10.0" + use_tebd_bias: bool = False + # TODO: what are the options here + precision: str = "float32" + concat_output_tebd: bool = False + + +class FittingNetHypers(TypedDict): + neuron: list[int] = [240, 240, 240] + resnet_dt: bool = True + seed: int = 1 + # TODO: what are the options here + precision: str = "float32" + # TODO: what are the options here + activation_function: str = "custom_silu:10.0" + # TODO: what are the options here + type: str = "ener" + numb_fparam: int = 0 + numb_aparam: int = 0 + dim_case_embd: int = 0 + trainable: bool = True + rcond: Optional[float] = None + atom_ener: list[float] = [] + use_aparam_as_mask: bool = False + + +class ModelHypers(TypedDict): + """Hyperparameters for the DPA3 model.""" + + type_map: list[str] = ["H", "C", "N", "O"] + + descriptor: DescriptorHypers = init_with_defaults(DescriptorHypers) + fitting_net: FittingNetHypers = init_with_defaults(FittingNetHypers) + + +############################## +# TRAINER HYPERPARAMETERS # +############################## + + +class TrainerHypers(TypedDict): + distributed: bool = False + """Whether to use distributed training""" + distributed_port: int = 39591 + """Port for DDP communication""" + batch_size: int = 8 + """The number of samples to use in each batch of training. This + hyperparameter controls the tradeoff between training speed and memory usage. In + general, larger batch sizes will lead to faster training, but might require more + memory.""" + num_epochs: int = 100 + """Number of epochs.""" + learning_rate: float = 0.001 + """Learning rate.""" + + # TODO: update the scheduler or not + scheduler_patience: int = 100 + scheduler_factor: float = 0.8 + + log_interval: int = 1 + """Interval to log metrics.""" + checkpoint_interval: int = 100 + """Interval to save checkpoints.""" + scale_targets: bool = True + """Normalize targets to unit std during training.""" + fixed_composition_weights: FixedCompositionWeights = {} + """Weights for atomic contributions. + + This is passed to the ``fixed_weights`` argument of + :meth:`CompositionModel.train_model + `, + see its documentation to understand exactly what to pass here. + """ + # fixed_scaling_weights: FixedScalerWeights = {} + # """Weights for target scaling. 
+ + # This is passed to the ``fixed_weights`` argument of + # :meth:`Scaler.train_model `, + # see its documentation to understand exactly what to pass here. + # """ + per_structure_targets: list[str] = [] + """Targets to calculate per-structure losses.""" + # num_workers: Optional[int] = None + # """Number of workers for data loading. If not provided, it is set + # automatically.""" + log_mae: bool = False + """Log MAE alongside RMSE""" + log_separate_blocks: bool = False + """Log per-block error.""" + best_model_metric: Literal["rmse_prod", "mae_prod", "loss"] = "rmse_prod" + """Metric used to select best checkpoint (e.g., ``rmse_prod``)""" + + loss: str | dict[str, LossSpecification] = "mse" + """This section describes the loss function to be used. See the + :ref:`loss-functions` for more details.""" diff --git a/src/metatrain/experimental/dpa3/model.py b/src/metatrain/experimental/dpa3/model.py new file mode 100644 index 0000000000..d593a6ca88 --- /dev/null +++ b/src/metatrain/experimental/dpa3/model.py @@ -0,0 +1,446 @@ +import logging +from typing import Any, Dict, List, Literal, Optional + +import metatensor.torch as mts +import torch +from deepmd.pt.model.model import get_standard_model +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatomic.torch import ( + AtomisticModel, + ModelCapabilities, + ModelMetadata, + ModelOutput, + NeighborListOptions, + System, +) + +from metatrain.utils.abc import ModelInterface +from metatrain.utils.additive import CompositionModel +from metatrain.utils.data import TargetInfo +from metatrain.utils.data.dataset import DatasetInfo +from metatrain.utils.dtype import dtype_to_str +from metatrain.utils.metadata import merge_metadata +from metatrain.utils.scaler import Scaler +from metatrain.utils.sum_over_atoms import sum_over_atoms + +from . 
import checkpoints
+from .documentation import ModelHypers
+
+
+# Data processing
+def concatenate_structures(systems: List[System]):
+    device = systems[0].positions.device
+    positions = []
+    species = []
+    cells = []
+    atom_nums: List[int] = []
+    node_counter = 0
+
+    atom_index_list: List[torch.Tensor] = []
+    system_index_list: List[torch.Tensor] = []
+
+    for i, system in enumerate(systems):
+        atom_nums.append(len(system.positions))
+        atom_index_list.append(torch.arange(start=0, end=len(system.positions)))
+        system_index_list.append(torch.full((len(system.positions),), i))
+    max_atom_num = max(atom_nums)
+    atom_index = torch.cat(atom_index_list, dim=0).to(torch.int32).to(device)
+    system_index = torch.cat(system_index_list, dim=0).to(torch.int32).to(device)
+
+    positions = torch.zeros(
+        (len(systems), max_atom_num, 3), dtype=systems[0].positions.dtype
+    )
+    species = torch.full((len(systems), max_atom_num), -1, dtype=systems[0].types.dtype)
+    cells = torch.stack(
+        [system.cell for system in systems]
+    )  # shape [batch_size, 3, 3], or the corresponding cell shape
+
+    for i, system in enumerate(systems):
+        positions[i, : len(system.positions)] = system.positions
+        species[i, : len(system.positions)] = system.types
+        cells[i] = system.cell
+        node_counter += len(system.positions)
+
+    return (
+        positions.to(device),
+        species.to(device),
+        cells.to(device),
+        atom_index,
+        system_index,
+    )
+
+
+# Model definition
+class DPA3(ModelInterface[ModelHypers]):
+    __checkpoint_version__ = 1
+    __supported_devices__ = ["cuda", "cpu"]
+    __supported_dtypes__ = [torch.float32, torch.float64]
+    __default_metadata__ = ModelMetadata(
+        references={
+            "implementation": [
+                "https://github.com/deepmodeling/deepmd-kit",
+            ],
+            "architecture": [
+                "DPA3: https://arxiv.org/abs/2506.01686",
+            ],
+        }
+    )
+
+    component_labels: Dict[str, List[List[Labels]]]  # torchscript needs this
+
+    def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None:
+        super().__init__(hypers, dataset_info, self.__default_metadata__)
+        self.atomic_types = dataset_info.atomic_types
+        self.dtype = self.hypers["descriptor"]["precision"]
+
+        if self.dtype == "float64":
+            self.dtype = torch.float64
+        elif self.dtype == "float32":
+            self.dtype = torch.float32
+        else:
+            raise ValueError(f"Unsupported precision: {self.dtype}")
+
+        self.requested_nl = NeighborListOptions(
+            cutoff=self.hypers["descriptor"]["repflow"]["e_rcut"],
+            full_list=True,
+            strict=True,
+        )
+        self.targets_keys = list(dataset_info.targets.keys())[0]
+
+        self.model = get_standard_model(hypers)
+
+        self.scaler = Scaler(hypers={}, dataset_info=dataset_info)
+        self.outputs: Dict[str, ModelOutput] = {}
+        self.single_label = Labels.single()
+
+        self.num_properties: Dict[str, Dict[str, int]] = {}
+
+        self.key_labels: Dict[str, Labels] = {}
+        self.component_labels: Dict[str, List[List[Labels]]] = {}
+        self.property_labels: Dict[str, List[Labels]] = {}
+        for target_name, target in dataset_info.targets.items():
+            self._add_output(target_name, target)
+
+        composition_model = CompositionModel(
+            hypers={},
+            dataset_info=DatasetInfo(
+                length_unit=dataset_info.length_unit,
+                atomic_types=self.atomic_types,
+                targets={
+                    target_name: target_info
+                    for target_name, target_info in dataset_info.targets.items()
+                    if CompositionModel.is_valid_target(target_name, target_info)
+                },
+            ),
+        )
+        additive_models = [composition_model]
+        self.additive_models = torch.nn.ModuleList(additive_models)
+
+    def _add_output(self, target_name: str, target: TargetInfo) -> None:
+        if not target.is_scalar:
+            raise
ValueError("The DPA3 architecture can only predict scalars.") + self.num_properties[target_name] = {} + self.key_labels[target_name] = target.layout.keys + self.component_labels[target_name] = [ + block.components for block in target.layout.blocks() + ] + self.property_labels[target_name] = [ + block.properties for block in target.layout.blocks() + ] + self.outputs[target_name] = ModelOutput( + quantity=target.quantity, + unit=target.unit, + per_atom=True, + ) + + def get_rcut(self): + return self.model.atomic_model.get_rcut() + + def get_sel(self): + return self.model.atomic_model.get_sel() + + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return [self.requested_nl] + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + device = systems[0].positions.device + + atype_dtype = systems[0].types.dtype + + if self.single_label.values.device != device: + self.single_label = self.single_label.to(device) + self.key_labels = { + output_name: label.to(device) + for output_name, label in self.key_labels.items() + } + self.component_labels = { + output_name: [ + [labels.to(device) for labels in components_block] + for components_block in components_tmap + ] + for output_name, components_tmap in self.component_labels.items() + } + self.property_labels = { + output_name: [labels.to(device) for labels in properties_tmap] + for output_name, properties_tmap in self.property_labels.items() + } + + return_dict: Dict[str, TensorMap] = {} + + (positions, species, cells, atom_index, system_index) = concatenate_structures( + systems + ) + + type_to_index = { + atomic_type: idx for idx, atomic_type in enumerate(self.atomic_types) + } + type_to_index[-1] = -1 + + atype = torch.tensor( + [[type_to_index[s.item()] for s in row] for row in species], + dtype=atype_dtype, + ).to(positions.device) + atype = atype.to(atype_dtype) + + if torch.all(cells == 0).item(): + box = None + else: + box = cells + + model_ret = self.model.forward_common( + positions, + atype, + box, + fparam=None, + aparam=None, + do_atomic_virial=False, + ) + + if self.model.get_fitting_net() is not None: + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.model.do_grad_r("energy"): + model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + + else: + model_predict["force"] = model_ret["dforce"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + else: + model_predict = model_ret + model_predict["updated_coord"] += positions + + atomic_properties: Dict[str, TensorMap] = {} + blocks: List[TensorBlock] = [] + + system_col = system_index + atom_col = atom_index + + values = torch.stack([system_col, atom_col], dim=0).transpose(0, 1) + invariant_coefficients = Labels( + names=["system", "atom"], values=values.to(device) + ) + + mask = torch.abs(model_predict["atom_energy"]) > 1e-10 + atomic_property_tensor = model_predict["atom_energy"][mask].unsqueeze(-1) + + blocks.append( + TensorBlock( + values=atomic_property_tensor, + samples=invariant_coefficients, + components=self.component_labels[self.targets_keys][0], + properties=self.property_labels[self.targets_keys][0].to(device), + ) + ) + + atomic_properties[self.targets_keys] = TensorMap( + self.key_labels[self.targets_keys].to(device), blocks + ) + + if selected_atoms is not None: + for output_name, tmap in atomic_properties.items(): + atomic_properties[output_name] = 
mts.slice( + tmap, axis="samples", selection=selected_atoms + ) + + for output_name, atomic_property in atomic_properties.items(): + if outputs[output_name].per_atom: + return_dict[output_name] = atomic_property + else: + # sum the atomic property to get the total property + return_dict[output_name] = sum_over_atoms(atomic_property) + + if not self.training: + # at evaluation, we also introduce the scaler and additive contributions + return_dict = self.scaler(systems, return_dict) + for additive_model in self.additive_models: + outputs_for_additive_model: Dict[str, ModelOutput] = {} + for name, output in outputs.items(): + if name in additive_model.outputs: + outputs_for_additive_model[name] = output + additive_contributions = additive_model( + systems, + outputs_for_additive_model, + selected_atoms, + ) + for name in additive_contributions: + return_dict[name] = mts.add( + return_dict[name], + additive_contributions[name], + ) + + return return_dict + + def restart(self, dataset_info: DatasetInfo) -> "DPA3": + # merge old and new dataset info + merged_info = self.dataset_info.union(dataset_info) + new_atomic_types = [ + at for at in merged_info.atomic_types if at not in self.atomic_types + ] + new_targets = { + key: value + for key, value in merged_info.targets.items() + if key not in self.dataset_info.targets + } + self.has_new_targets = len(new_targets) > 0 + + if len(new_atomic_types) > 0: + raise ValueError( + f"New atomic types found in the dataset: {new_atomic_types}. " + "The DPA3 model does not support adding new atomic types." + ) + + # register new outputs as new last layers + for target_name, target in new_targets.items(): + self._add_output(target_name, target) + + self.dataset_info = merged_info + + # restart the composition and scaler models + self.additive_models[0].restart( + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + self.scaler.restart(dataset_info) + + return self + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + context: Literal["restart", "finetune", "export"], + ) -> "DPA3": + model_data = checkpoint["model_data"] + + if context == "restart": + logging.info(f"Using latest model from epoch {checkpoint['epoch']}") + model_state_dict = checkpoint["model_state_dict"] + elif context in {"finetune", "export"}: + logging.info(f"Using best model from epoch {checkpoint['best_epoch']}") + model_state_dict = checkpoint["best_model_state_dict"] + if model_state_dict is None: + model_state_dict = checkpoint["model_state_dict"] + else: + raise ValueError("Unknown context tag for checkpoint loading!") + + # Create the model + model = cls( + hypers=model_data["model_hypers"], + dataset_info=model_data["dataset_info"], + ) + dtype = next(iter(model_state_dict.values())).dtype + model.to(dtype).load_state_dict(model_state_dict) + model.additive_models[0].sync_tensor_maps() + + # Loading the metadata from the checkpoint + metadata = checkpoint.get("metadata", None) + if metadata is not None: + model.__default_metadata__ = metadata + + return model + + def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: + dtype = next(self.parameters()).dtype + if dtype not in self.__supported_dtypes__: + raise ValueError(f"unsupported dtype {self.dtype} for DPA3") + + # Make sure the model is all in the same dtype 
+ # For example, after training, the additive models could still be in + # float64 + self.to(dtype) + + # Additionally, the composition model contains some `TensorMap`s that cannot + # be registered correctly with Pytorch. This funciton moves them: + + self.additive_models[0].weights_to(torch.device("cpu"), torch.float64) + + interaction_ranges = [self.hypers["descriptor"]["repflow"]["e_rcut"]] + for additive_model in self.additive_models: + if hasattr(additive_model, "cutoff_radius"): + interaction_ranges.append(additive_model.cutoff_radius) + interaction_range = max(interaction_ranges) + + capabilities = ModelCapabilities( + outputs=self.outputs, + atomic_types=self.atomic_types, + interaction_range=interaction_range, + length_unit=self.dataset_info.length_unit, + supported_devices=self.__supported_devices__, + dtype=dtype_to_str(dtype), + ) + if metadata is None: + metadata = self.__default_metadata__ + else: + metadata = merge_metadata(self.__default_metadata__, metadata) + + return AtomisticModel(self.eval(), metadata, capabilities) + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["model_ckpt_version"] == v: + update = getattr(checkpoints, f"model_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["model_ckpt_version"] = v + 1 + + if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using model " + f"version {checkpoint['model_ckpt_version']}, while the current model " + f"version is {cls.__checkpoint_version__}." + ) + + return checkpoint + + def get_checkpoint(self) -> Dict: + checkpoint = { + "architecture_name": "experimental.dpa3", + "model_ckpt_version": self.__checkpoint_version__, + "metadata": self.metadata, + "model_data": { + "model_hypers": self.hypers, + "dataset_info": self.dataset_info, + }, + "epoch": None, + "best_epoch": None, + "model_state_dict": self.state_dict(), + "best_model_state_dict": None, + } + return checkpoint + + def supported_outputs(self) -> Dict[str, ModelOutput]: + return self.outputs diff --git a/src/metatrain/experimental/dpa3/tests/__init__.py b/src/metatrain/experimental/dpa3/tests/__init__.py new file mode 100644 index 0000000000..f9f29d3a06 --- /dev/null +++ b/src/metatrain/experimental/dpa3/tests/__init__.py @@ -0,0 +1,12 @@ +from pathlib import Path + +from metatrain.utils.architectures import get_default_hypers + + +DATASET_PATH = str(Path(__file__).parents[5] / "tests/resources/qm9_reduced_100.xyz") +DATASET_WITH_FORCES_PATH = str( + Path(__file__).parents[5] / "tests/resources/carbon_reduced_100.xyz" +) + +DEFAULT_HYPERS = get_default_hypers("experimental.dpa3") +MODEL_HYPERS = DEFAULT_HYPERS["model"] diff --git a/src/metatrain/experimental/dpa3/tests/checkpoints/model-v1_trainer-v1.ckpt.gz b/src/metatrain/experimental/dpa3/tests/checkpoints/model-v1_trainer-v1.ckpt.gz new file mode 100644 index 0000000000..fcce107c73 Binary files /dev/null and b/src/metatrain/experimental/dpa3/tests/checkpoints/model-v1_trainer-v1.ckpt.gz differ diff --git a/src/metatrain/experimental/dpa3/tests/test_basic.py b/src/metatrain/experimental/dpa3/tests/test_basic.py new file mode 100644 index 0000000000..acc6f51dd2 --- /dev/null +++ b/src/metatrain/experimental/dpa3/tests/test_basic.py @@ -0,0 +1,67 @@ +import copy + +import pytest + +from metatrain.utils.architectures import get_default_hypers +from metatrain.utils.testing import ( + ArchitectureTests, + AutogradTests, + 
CheckpointTests, + ExportedTests, + InputTests, + OutputTests, + TorchscriptTests, + TrainingTests, +) + + +class DPA3Tests(ArchitectureTests): + architecture = "experimental.dpa3" + + @pytest.fixture + def minimal_model_hypers(self) -> dict: + """Minimal hyperparameters for a DPA3 model for the smallest + checkpoint possible. + + :return: Hyperparameters for the model. + """ + hypers = copy.deepcopy(get_default_hypers(self.architecture)["model"]) + hypers["descriptor"]["repflow"]["n_dim"] = 2 + hypers["descriptor"]["repflow"]["e_dim"] = 2 + hypers["descriptor"]["repflow"]["a_dim"] = 2 + hypers["descriptor"]["repflow"]["e_sel"] = 1 + hypers["descriptor"]["repflow"]["a_sel"] = 1 + hypers["descriptor"]["repflow"]["axis_neuron"] = 1 + hypers["descriptor"]["repflow"]["nlayers"] = 1 + hypers["fitting_net"]["neuron"] = [1, 1] + return hypers + + +class TestInput(InputTests, DPA3Tests): ... + + +class TestOutput(OutputTests, DPA3Tests): + supports_multiscalar_outputs = False + supports_spherical_outputs = False + supports_vector_outputs = False + supports_features = False + supports_last_layer_features = False + + +class TestAutograd(AutogradTests, DPA3Tests): + cuda_nondet_tolerance = 1e-12 + + +class TestTorchscript(TorchscriptTests, DPA3Tests): + float_hypers = ["descriptor.repflow.e_rcut", "descriptor.repflow.e_rcut_smth"] + supports_spherical_outputs = False + + +class TestExported(ExportedTests, DPA3Tests): ... + + +class TestTraining(TrainingTests, DPA3Tests): ... + + +class TestCheckpoints(CheckpointTests, DPA3Tests): + incompatible_trainer_checkpoints = [] diff --git a/src/metatrain/experimental/dpa3/tests/test_regression.py b/src/metatrain/experimental/dpa3/tests/test_regression.py new file mode 100644 index 0000000000..20fb0ff6c7 --- /dev/null +++ b/src/metatrain/experimental/dpa3/tests/test_regression.py @@ -0,0 +1,143 @@ +import random + +import numpy as np +import torch +from metatomic.torch import ModelOutput +from omegaconf import OmegaConf + +from metatrain.experimental.dpa3 import DPA3, Trainer +from metatrain.utils.data import Dataset, DatasetInfo +from metatrain.utils.data.readers import ( + read_systems, + read_targets, +) +from metatrain.utils.data.target_info import get_energy_target_info +from metatrain.utils.evaluate_model import evaluate_model +from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists + +from . 
import DATASET_PATH, DATASET_WITH_FORCES_PATH, DEFAULT_HYPERS, MODEL_HYPERS + + +def test_regression_init(): + """Regression test for the model at initialization""" + random.seed(0) + np.random.seed(0) + torch.manual_seed(0) + + targets = {} + targets["mtt::U0"] = get_energy_target_info({"quantity": "energy", "unit": "eV"}) + + dataset_info = DatasetInfo( + length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets + ) + model = DPA3(MODEL_HYPERS, dataset_info).to("cpu") + + # Predict on the first five systems + systems = read_systems(DATASET_PATH)[:5] + systems = [system.to(torch.float64) for system in systems] + for system in systems: + get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + + output = model( + systems, + {"mtt::U0": ModelOutput(quantity="energy", unit="", per_atom=False)}, + ) + + expected_output = torch.tensor( + [ + [8.893970727921], + [7.150644659996], + [5.338875532150], + [7.145487308502], + [5.402073264122], + ], + dtype=torch.float64, + ) + + # if you need to change the hardcoded values: + torch.set_printoptions(precision=12) + print(output["mtt::U0"].block().values) + + torch.testing.assert_close(output["mtt::U0"].block().values, expected_output) + + +def test_regression_energies_forces_train(): + """Regression test for the model when trained for 2 epoch on a small dataset""" + random.seed(0) + np.random.seed(0) + torch.manual_seed(0) + + systems = read_systems(DATASET_WITH_FORCES_PATH) + + conf = { + "energy": { + "quantity": "energy", + "read_from": DATASET_WITH_FORCES_PATH, + "reader": "ase", + "key": "energy", + "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, + "forces": {"read_from": DATASET_WITH_FORCES_PATH, "key": "force"}, + "stress": False, + "virial": False, + } + } + + targets, target_info_dict = read_targets(OmegaConf.create(conf)) + targets = {"energy": targets["energy"]} + dataset = Dataset.from_dict({"system": systems, "energy": targets["energy"]}) + hypers = DEFAULT_HYPERS.copy() + hypers["training"]["num_epochs"] = 1 + hypers["training"]["scheduler_patience"] = 1 + hypers["training"]["fixed_composition_weights"] = {} + + dataset_info = DatasetInfo( + length_unit="Angstrom", atomic_types=[6], targets=target_info_dict + ) + model = DPA3(MODEL_HYPERS, dataset_info).to("cpu") + trainer = Trainer(hypers["training"]) + trainer.train( + model=model, + dtype=torch.float32, + devices=[torch.device("cpu")], + train_datasets=[dataset], + val_datasets=[dataset], + checkpoint_dir=".", + ) + + # Predict on the first five systems + systems = [system.to(torch.float64) for system in systems] + for system in systems: + get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + + output = evaluate_model( + model, systems[:5], targets=target_info_dict, is_training=False + ) + + expected_output = torch.tensor( + [ + [0.630174279213], + [0.653932452202], + [0.664113998413], + [0.590713620186], + [0.635889530182], + ], + dtype=torch.float64, + ) + + expected_gradients_output = torch.tensor( + [0.006374867036, -0.008849388247, 0.030855362978], dtype=torch.float64 + ) + + # if you need to change the hardcoded values: + torch.set_printoptions(precision=12) + print(output["energy"].block().values) + print(output["energy"].block().gradient("positions").values.squeeze(-1)[0]) + + torch.testing.assert_close(output["energy"].block().values, expected_output) + torch.testing.assert_close( + output["energy"].block().gradient("positions").values[0, :, 0], + expected_gradients_output, + ) diff --git 
a/src/metatrain/experimental/dpa3/trainer.py b/src/metatrain/experimental/dpa3/trainer.py new file mode 100644 index 0000000000..848b78b807 --- /dev/null +++ b/src/metatrain/experimental/dpa3/trainer.py @@ -0,0 +1,563 @@ +import copy +import logging +from pathlib import Path +from typing import Any, Dict, List, Literal, Union, cast + +import torch +import torch.distributed +from torch.utils.data import DataLoader, DistributedSampler + +from metatrain.utils.abc import TrainerInterface +from metatrain.utils.additive import remove_additive +from metatrain.utils.data import ( + CollateFn, + CombinedDataLoader, + Dataset, + _is_disk_dataset, + unpack_batch, +) +from metatrain.utils.distributed.distributed_data_parallel import ( + DistributedDataParallel, +) +from metatrain.utils.distributed.slurm import DistributedEnvironment +from metatrain.utils.evaluate_model import evaluate_model +from metatrain.utils.io import check_file_extension +from metatrain.utils.logging import ROOT_LOGGER, MetricLogger +from metatrain.utils.loss import LossAggregator, LossSpecification +from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) +from metatrain.utils.per_atom import average_by_num_atoms +from metatrain.utils.scaler import remove_scale +from metatrain.utils.transfer import ( + batch_to, +) + +from .documentation import TrainerHypers +from .model import DPA3 + + +class Trainer(TrainerInterface[TrainerHypers]): + __checkpoint_version__ = 1 + + def __init__(self, hypers: TrainerHypers): + super().__init__(hypers) + + self.optimizer_state_dict = None + self.scheduler_state_dict = None + self.epoch: int | None = None + self.best_epoch: int | None = None + self.best_metric: float | None = None + self.best_model_state_dict = None + self.best_optimizer_state_dict = None + + def train( + self, + model: DPA3, + dtype: torch.dtype, + devices: List[torch.device], + train_datasets: List[Union[Dataset, torch.utils.data.Subset]], + val_datasets: List[Union[Dataset, torch.utils.data.Subset]], + checkpoint_dir: str, + ): + assert dtype in DPA3.__supported_dtypes__ + + is_distributed = self.hypers["distributed"] + + if is_distributed: + distr_env = DistributedEnvironment(self.hypers["distributed_port"]) + torch.distributed.init_process_group(backend="nccl") + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + else: + rank = 0 + + if is_distributed: + if len(devices) > 1: + raise ValueError( + "Requested distributed training with the `multi-gpu` device. " + " If you want to run distributed training with DPA3, please " + "set `device` to cuda." 
+                )
+            # the calculation of the device number works both when GPUs on different
+            # processes are not visible to each other and when they are
+            device_number = distr_env.local_rank % torch.cuda.device_count()
+            device = torch.device("cuda", device_number)
+        else:
+            device = devices[
+                0
+            ]  # only one device, as we don't support multi-gpu for now
+
+        if is_distributed:
+            logging.info(f"Training on {world_size} devices with dtype {dtype}")
+        else:
+            logging.info(f"Training on device {device} with dtype {dtype}")
+
+        # Calculate the neighbor lists in advance (in particular, this
+        # needs to happen before the additive models are trained, as they
+        # might need them):
+        logging.info("Calculating neighbor lists for the datasets")
+        requested_neighbor_lists = get_requested_neighbor_lists(model)
+        for dataset in train_datasets + val_datasets:
+            # If the dataset is a disk dataset, the NLs are already attached, we will
+            # just check the first system
+            if _is_disk_dataset(dataset):
+                system = dataset[0]["system"]
+                for options in requested_neighbor_lists:
+                    if options not in system.known_neighbor_lists():
+                        raise ValueError(
+                            "The requested neighbor lists are not attached to the "
+                            f"system. Neighbor list {options} is missing from the "
+                            "first system in the disk dataset. Make sure you save "
+                            "the neighbor lists in the systems when saving the dataset."
+                        )
+            else:
+                for sample in dataset:
+                    system = sample["system"]
+                    # The following line attaches the neighbor lists to the system,
+                    # and doesn't require reassigning the system to the dataset:
+                    get_system_with_neighbor_lists(system, requested_neighbor_lists)
+
+        # Move the model to the device and dtype:
+        model.to(device=device, dtype=dtype)
+        # The additive models of DPA3 are always kept in float64 (to avoid
+        # numerical errors in the composition weights, which can be very large).
+ for additive_model in model.additive_models: + additive_model.to(dtype=torch.float64) + + logging.info("Calculating composition weights") + model.additive_models[0].train_model( # this is the composition model + train_datasets, + model.additive_models[1:], + self.hypers["batch_size"], + is_distributed, + self.hypers["fixed_composition_weights"], + ) + + if self.hypers["scale_targets"]: + logging.info("Calculating scaling weights") + model.scaler.train_model( + train_datasets, + model.additive_models, + self.hypers["batch_size"], + is_distributed, + # TODO: fixed_scaling_weights + ) + + if is_distributed: + model = DistributedDataParallel(model, device_ids=[device]) + + logging.info("Setting up data loaders") + + if is_distributed: + train_samplers = [ + DistributedSampler( + train_dataset, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + ) + for train_dataset in train_datasets + ] + val_samplers = [ + DistributedSampler( + val_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + for val_dataset in val_datasets + ] + else: + train_samplers = [None] * len(train_datasets) + val_samplers = [None] * len(val_datasets) + + # Create a collate function: + targets_keys = list( + (model.module if is_distributed else model).dataset_info.targets.keys() + ) + collate_fn = CollateFn(target_keys=targets_keys) + + # Create dataloader for the training datasets: + train_dataloaders = [] + for train_dataset, train_sampler in zip( + train_datasets, train_samplers, strict=True + ): + if len(train_dataset) < self.hypers["batch_size"]: + raise ValueError( + f"A training dataset has fewer samples " + f"({len(train_dataset)}) than the batch size " + f"({self.hypers['batch_size']}). " + "Please reduce the batch size." + ) + train_dataloaders.append( + DataLoader( + dataset=train_dataset, + batch_size=self.hypers["batch_size"], + sampler=train_sampler, + shuffle=( + # the sampler takes care of this (if present) + train_sampler is None + ), + drop_last=( + # the sampler takes care of this (if present) + train_sampler is None + ), + collate_fn=collate_fn, + ) + ) + train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True) + + # Create dataloader for the validation datasets: + val_dataloaders = [] + for val_dataset, val_sampler in zip(val_datasets, val_samplers, strict=True): + if len(val_dataset) < self.hypers["batch_size"]: + raise ValueError( + f"A validation dataset has fewer samples " + f"({len(val_dataset)}) than the batch size " + f"({self.hypers['batch_size']}). " + "Please reduce the batch size." 
+ ) + val_dataloaders.append( + DataLoader( + dataset=val_dataset, + batch_size=self.hypers["batch_size"], + sampler=val_sampler, + shuffle=False, + drop_last=False, + collate_fn=collate_fn, + ) + ) + val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False) + + # Extract all the possible outputs and their gradients: + train_targets = (model.module if is_distributed else model).dataset_info.targets + outputs_list = [] + for target_name, target_info in train_targets.items(): + outputs_list.append(target_name) + for gradient_name in target_info.gradients: + outputs_list.append(f"{target_name}_{gradient_name}_gradients") + + # Create a loss function: + loss_hypers = cast(Dict[str, LossSpecification], self.hypers["loss"]) # mypy + loss_fn = LossAggregator( + targets=train_targets, + config=loss_hypers, + ) + logging.info("Using the following loss functions:") + for name, info in loss_fn.metadata.items(): + logging.info(f"{name}:") + main = {k: v for k, v in info.items() if k != "gradients"} + logging.info(main) + if "gradients" not in info or len(info["gradients"]) == 0: + continue + logging.info("With gradients:") + for grad, ginfo in info["gradients"].items(): + logging.info(f"\t{name}::{grad}: {ginfo}") + + # Create an optimizer: + optimizer = torch.optim.Adam( + model.parameters(), lr=self.hypers["learning_rate"] + ) + if self.optimizer_state_dict is not None: + # try to load the optimizer state dict, but this is only possible + # if there are no new targets in the model (new parameters) + if not (model.module if is_distributed else model).has_new_targets: + optimizer.load_state_dict(self.optimizer_state_dict) + + # Create a scheduler: + lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + factor=self.hypers["scheduler_factor"], + patience=self.hypers["scheduler_patience"], + threshold=0.001, + min_lr=1e-5, + ) + if self.scheduler_state_dict is not None: + # same as the optimizer, try to load the scheduler state dict + if not (model.module if is_distributed else model).has_new_targets: + lr_scheduler.load_state_dict(self.scheduler_state_dict) + + # per-atom targets: + per_structure_targets = self.hypers["per_structure_targets"] + + # Log the initial learning rate: + old_lr = optimizer.param_groups[0]["lr"] + logging.info(f"Initial learning rate: {old_lr}") + + start_epoch = 0 if self.epoch is None else self.epoch + 1 + + # Train the model: + if self.best_metric is None: + self.best_metric = float("inf") + logging.info("Starting training") + epoch = start_epoch + for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]): + if is_distributed: + for train_sampler in train_samplers: + train_sampler.set_epoch(epoch) + + train_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + val_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + if self.hypers["log_mae"]: + train_mae_calculator = MAEAccumulator( + self.hypers["log_separate_blocks"] + ) + val_mae_calculator = MAEAccumulator(self.hypers["log_separate_blocks"]) + + train_loss = 0.0 + + for batch in train_dataloader: + optimizer.zero_grad() + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, device=device + ) + for additive_model in ( + model.module if is_distributed else model + ).additive_models: + targets = remove_additive( + systems, targets, additive_model, train_targets + ) + targets = remove_scale( + systems, targets, (model.module if is_distributed else model).scaler + ) + 
systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype + ) + + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=True, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + + train_loss_batch = loss_fn(predictions, targets, extra_data) + + train_loss_batch.backward() + optimizer.step() + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(train_loss_batch) + train_loss += train_loss_batch.item() + train_rmse_calculator.update(predictions, targets) + if self.hypers["log_mae"]: + train_mae_calculator.update(predictions, targets) + + finalized_train_info = train_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_train_info.update( + train_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + val_loss = 0.0 + for batch in val_dataloader: + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, device=device + ) + for additive_model in ( + model.module if is_distributed else model + ).additive_models: + targets = remove_additive( + systems, targets, additive_model, train_targets + ) + targets = remove_scale( + systems, targets, (model.module if is_distributed else model).scaler + ) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype + ) + + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=False, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + + val_loss_batch = loss_fn(predictions, targets, extra_data) + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(val_loss_batch) + val_loss += val_loss_batch.item() + val_rmse_calculator.update(predictions, targets) + if self.hypers["log_mae"]: + val_mae_calculator.update(predictions, targets) + + finalized_val_info = val_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_val_info.update( + val_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + # Now we log the information: + finalized_train_info = {"loss": train_loss, **finalized_train_info} + finalized_val_info = {"loss": val_loss, **finalized_val_info} + + if epoch == start_epoch: + metric_logger = MetricLogger( + log_obj=ROOT_LOGGER, + dataset_info=( + model.module if is_distributed else model + ).dataset_info, + initial_metrics=[finalized_train_info, finalized_val_info], + names=["training", "validation"], + ) + if epoch % self.hypers["log_interval"] == 0: + metric_logger.log( + metrics=[finalized_train_info, finalized_val_info], + epoch=epoch, + rank=rank, + ) + + lr_scheduler.step(val_loss) + new_lr = lr_scheduler.get_last_lr()[0] + if new_lr != old_lr: + if new_lr < 1e-7: + logging.info("Learning 
rate is too small, stopping training") + break + else: + logging.info(f"Changing learning rate from {old_lr} to {new_lr}") + old_lr = new_lr + # load best model and optimizer state dict, re-initialize scheduler + (model.module if is_distributed else model).load_state_dict( + self.best_model_state_dict + ) + optimizer.load_state_dict(self.best_optimizer_state_dict) + for param_group in optimizer.param_groups: + param_group["lr"] = new_lr + lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + factor=self.hypers["scheduler_factor"], + patience=self.hypers["scheduler_patience"], + ) + + val_metric = get_selected_metric( + finalized_val_info, self.hypers["best_model_metric"] + ) + if val_metric < self.best_metric: + self.best_metric = val_metric + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + if epoch % self.hypers["checkpoint_interval"] == 0: + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + self.epoch = epoch + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_{epoch}.ckpt", + ) + + # prepare for the checkpoint that will be saved outside the function + self.epoch = epoch + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + checkpoint = model.get_checkpoint() + checkpoint.update( + { + "train_hypers": self.hypers, + "trainer_ckpt_version": self.__checkpoint_version__, + "epoch": self.epoch, + "optimizer_state_dict": self.optimizer_state_dict, + "scheduler_state_dict": self.scheduler_state_dict, + "best_epoch": self.best_epoch, + "best_metric": self.best_metric, + "best_model_state_dict": self.best_model_state_dict, + "best_optimizer_state_dict": self.best_optimizer_state_dict, + } + ) + + if is_distributed: + torch.distributed.destroy_process_group() + + def save_checkpoint(self, model, path: Union[str, Path]): + checkpoint = model.get_checkpoint() + checkpoint.update( + { + "train_hypers": self.hypers, + "trainer_ckpt_version": self.__checkpoint_version__, + "epoch": self.epoch, + "optimizer_state_dict": self.optimizer_state_dict, + "scheduler_state_dict": self.scheduler_state_dict, + "best_epoch": self.best_epoch, + "best_metric": self.best_metric, + "best_model_state_dict": self.best_model_state_dict, + "best_optimizer_state_dict": self.best_optimizer_state_dict, + } + ) + torch.save( + checkpoint, + check_file_extension(path, ".ckpt"), + ) + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + hypers: TrainerHypers, + context: Literal["restart", "finetune"], # not used at the moment + ) -> "Trainer": + trainer = cls(hypers) + trainer.optimizer_state_dict = checkpoint["optimizer_state_dict"] + trainer.scheduler_state_dict = checkpoint["scheduler_state_dict"] + trainer.epoch = checkpoint["epoch"] + trainer.best_epoch = checkpoint["best_epoch"] + trainer.best_metric = checkpoint["best_metric"] + trainer.best_model_state_dict = checkpoint["best_model_state_dict"] + trainer.best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"] + + return trainer + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + if checkpoint["trainer_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the 
checkpoint: the checkpoint is using trainer " + f"version {checkpoint['trainer_ckpt_version']}, while the current " + f"trainer version is {cls.__checkpoint_version__}." + ) + return checkpoint diff --git a/src/metatrain/utils/testing/architectures.py b/src/metatrain/utils/testing/architectures.py index 4d2dfd008b..d6edd27076 100644 --- a/src/metatrain/utils/testing/architectures.py +++ b/src/metatrain/utils/testing/architectures.py @@ -145,11 +145,22 @@ def per_atom(self, request: pytest.FixtureRequest) -> bool: """ return request.param + @pytest.fixture(params=[1, 5]) + def num_subtargets(self, request: pytest.FixtureRequest) -> int: + """Fixture to provide different numbers of subtargets for + testing. + + :param request: The pytest request fixture. + :return: The number of subtargets. + """ + return request.param + @pytest.fixture - def dataset_info_scalar(self, per_atom: bool) -> DatasetInfo: + def dataset_info_scalar(self, num_subtargets: int, per_atom: bool) -> DatasetInfo: """Fixture that provides a basic ``DatasetInfo`` with a scalar target for testing. + :param num_subtargets: The number of scalars in the target. :param per_atom: Whether the target is per-atom or not. :return: A ``DatasetInfo`` instance with a scalar target. """ @@ -163,7 +174,7 @@ def dataset_info_scalar(self, per_atom: bool) -> DatasetInfo: "quantity": "scalar", "unit": "", "type": "scalar", - "num_subtargets": 5, + "num_subtargets": num_subtargets, "per_atom": per_atom, }, ) diff --git a/src/metatrain/utils/testing/output.py b/src/metatrain/utils/testing/output.py index 929d2e3d15..a3d198c16b 100644 --- a/src/metatrain/utils/testing/output.py +++ b/src/metatrain/utils/testing/output.py @@ -36,6 +36,8 @@ class OutputTests(ArchitectureTests): supports_scalar_outputs: bool = True """Whether the model supports scalar outputs.""" + supports_multiscalar_outputs: bool = True + """Whether the model supports outputs with multiple scalar subtargets.""" supports_vector_outputs: bool = True """Whether the model supports vector outputs.""" supports_spherical_outputs: bool = True @@ -132,7 +134,11 @@ def _get_output( return model([system], {k: ModelOutput(per_atom=per_atom) for k in outputs}) def test_output_scalar( - self, model_hypers: dict, dataset_info_scalar: DatasetInfo, per_atom: bool + self, + model_hypers: dict, + dataset_info_scalar: DatasetInfo, + num_subtargets: int, + per_atom: bool, ) -> None: """Tests that forward pass works for scalar outputs. @@ -140,7 +146,11 @@ def test_output_scalar( and values shape. This test is skipped if the model does not support scalar outputs, - i.e., if ``supports_scalar_outputs`` is set to ``False``. + i.e., if ``supports_scalar_outputs`` is set to ``False``. The test + is run twice, once for single scalar outputs, and once for a scalar + output with multiple subtargets. If the model does not support + multiple scalar subtargets, set ``supports_multiscalar_outputs`` + to ``False``, which will skip the test for multiple subtargets. If this test is failing, your model might: @@ -150,23 +160,29 @@ def test_output_scalar( :param model_hypers: Hyperparameters to initialize the model. :param dataset_info_scalar: Dataset information with scalar outputs. + :param num_subtargets: The number of scalars that the target contains. :param per_atom: Whether the requested outputs are per-atom or not. 
""" if not self.supports_scalar_outputs: pytest.skip(f"{self.architecture} does not support scalar outputs.") + if num_subtargets > 1 and not self.supports_multiscalar_outputs: + pytest.skip( + f"{self.architecture} does not support multiple scalar subtargets." + ) + outputs = self._get_output( model_hypers, dataset_info_scalar, per_atom, ["scalar"] ) if per_atom: assert outputs["scalar"].block().samples.names == ["system", "atom"] - assert outputs["scalar"].block().values.shape == (4, 5) + assert outputs["scalar"].block().values.shape == (4, num_subtargets) else: assert outputs["scalar"].block().samples.names == ["system"], ( outputs["scalar"].block().samples.names ) - assert outputs["scalar"].block().values.shape == (1, 5) + assert outputs["scalar"].block().values.shape == (1, num_subtargets) def test_output_vector( self, model_hypers: dict, dataset_info_vector: DatasetInfo, per_atom: bool diff --git a/src/metatrain/utils/testing/torchscript.py b/src/metatrain/utils/testing/torchscript.py index 6703ce336f..b4ebcc9746 100644 --- a/src/metatrain/utils/testing/torchscript.py +++ b/src/metatrain/utils/testing/torchscript.py @@ -1,6 +1,7 @@ import copy from typing import Any +import pytest import torch from metatomic.torch import System @@ -20,6 +21,9 @@ class TorchscriptTests(ArchitectureTests): that are floats. A test will set these to integers to test that TorchScript compilation works in that case.""" + supports_spherical_outputs: bool = True + """Whether the model supports spherical tensor outputs.""" + def jit_compile(self, model: ModelInterface) -> torch.jit.ScriptModule: """JIT compiles the given model. @@ -109,11 +113,17 @@ def test_torchscript_spherical( ) -> None: """Tests that there is no problem with jitting with spherical targets. + This test is skipped if the model does not support spherical outputs, + i.e., if ``supports_spherical_outputs`` is set to ``False``. + :param model_hypers: Hyperparameters to initialize the model. :param dataset_info_spherical: Dataset to initialize the model (containing spherical targets). """ + if not self.supports_spherical_outputs: + pytest.skip("Model does not support spherical outputs.") + self.test_torchscript( model_hypers=model_hypers, dataset_info=dataset_info_spherical, diff --git a/tests/utils/test_architectures.py b/tests/utils/test_architectures.py index 8b82b0effa..bb233a8ca5 100644 --- a/tests/utils/test_architectures.py +++ b/tests/utils/test_architectures.py @@ -23,17 +23,17 @@ def is_None(*args, **kwargs) -> None: def test_find_all_architectures(): all_arches = find_all_architectures() - - assert len(all_arches) == 8 + assert len(all_arches) == 9 assert "gap" in all_arches assert "pet" in all_arches assert "soap_bpnn" in all_arches - assert "deprecated.nanopet" in all_arches + assert "experimental.dpa3" in all_arches assert "experimental.flashmd" in all_arches assert "experimental.classifier" in all_arches - assert "llpr" in all_arches assert "experimental.mace" in all_arches + assert "deprecated.nanopet" in all_arches + assert "llpr" in all_arches def test_get_architecture_path(): diff --git a/tox.ini b/tox.ini index 7b39199e67..9316e95bb5 100644 --- a/tox.ini +++ b/tox.ini @@ -1,8 +1,8 @@ [tox] min_version = 4.0 requires = - tox>=4.31 - tox-uv>=1.23 + tox>=4.31 + tox-uv>=1.23 # these are the default environments, i.e. 
the list of tests running when you # execute `tox` in the command-line without anything else envlist = @@ -13,6 +13,7 @@ envlist = soap-bpnn-tests pet-tests nanopet-tests + dpa3-tests flashmd-tests llpr-tests classifier-tests @@ -156,6 +157,16 @@ changedir = src/metatrain/pet/tests/ commands = pytest {posargs} +[testenv:dpa3-tests] +description = Run DPA3 tests with pytest +passenv = * +deps = + pytest +extras = dpa3 +changedir = src/metatrain/experimental/dpa3/tests/ +commands = + pytest {posargs} + [testenv:gap-tests] description = Run GAP tests with pytest passenv = *