From 81ea25bc46395e780fa9b76246f19852800a4013 Mon Sep 17 00:00:00 2001 From: liwentao Date: Wed, 13 Aug 2025 06:33:17 +0000 Subject: [PATCH 01/18] add dpa3 architecture --- docs/static/qm9/options.yaml | 4 +- examples/basic_usage/run_dpa3.sh | 9 + pyproject.toml | 3 + src/metatrain/dpa3/__init__.py | 6 + src/metatrain/dpa3/default-hypers.yaml | 75 ++++ src/metatrain/dpa3/model.py | 455 ++++++++++++++++++++ src/metatrain/dpa3/trainer.py | 556 +++++++++++++++++++++++++ 7 files changed, 1106 insertions(+), 2 deletions(-) create mode 100644 examples/basic_usage/run_dpa3.sh create mode 100644 src/metatrain/dpa3/__init__.py create mode 100644 src/metatrain/dpa3/default-hypers.yaml create mode 100644 src/metatrain/dpa3/model.py create mode 100644 src/metatrain/dpa3/trainer.py diff --git a/docs/static/qm9/options.yaml b/docs/static/qm9/options.yaml index c0c0ddf7bf..b72b041c1f 100644 --- a/docs/static/qm9/options.yaml +++ b/docs/static/qm9/options.yaml @@ -1,8 +1,8 @@ # architecture used to train the model architecture: - name: soap_bpnn + name: dpa3 training: - num_epochs: 5 # a very short training run + num_epochs: 200 # a very short training run batch_size: 10 # Mandatory section defining the parameters for system and target data of the diff --git a/examples/basic_usage/run_dpa3.sh b/examples/basic_usage/run_dpa3.sh new file mode 100644 index 0000000000..325168d77f --- /dev/null +++ b/examples/basic_usage/run_dpa3.sh @@ -0,0 +1,9 @@ +export METATENSOR_DEBUG_EXTENSIONS_LOADING=1 + +mtt train options.yaml + +package_dir=$(python -c "import site; print(site.getsitepackages()[0])") +cp $package_dir/deepmd/lib/*.so extensions/deepmd/lib/ +cp $package_dir/deepmd_kit.libs/*.so* extensions/deepmd/lib/ + +mtt eval model.pt eval.yaml -e extensions/ \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 47990590f7..49d1d4aa01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,6 +79,9 @@ gap = [ "skmatter", "scipy", ] +dpa3 = [ + "deepmd-kit>=3.1.0" +] [tool.check-manifest] ignore = ["src/metatrain/_version.py"] diff --git a/src/metatrain/dpa3/__init__.py b/src/metatrain/dpa3/__init__.py new file mode 100644 index 0000000000..fdd4575b92 --- /dev/null +++ b/src/metatrain/dpa3/__init__.py @@ -0,0 +1,6 @@ +from .model import DPA3 +from .trainer import Trainer + + +__model__ = DPA3 +__trainer__ = Trainer \ No newline at end of file diff --git a/src/metatrain/dpa3/default-hypers.yaml b/src/metatrain/dpa3/default-hypers.yaml new file mode 100644 index 0000000000..7d6143dda5 --- /dev/null +++ b/src/metatrain/dpa3/default-hypers.yaml @@ -0,0 +1,75 @@ +architecture: + name: dpa3 + model: + type_map: + - "H" + - "C" + - "N" + - "O" + descriptor: + type: "dpa3" + repflow: + n_dim: 128 + e_dim: 64 + a_dim: 32 + nlayers: 6 + e_rcut: 6.0 + e_rcut_smth: 5.3 + e_sel: 1200 + a_rcut: 4.0 + a_rcut_smth: 3.5 + a_sel: 300 + axis_neuron: 4 + skip_stat: true + a_compress_rate: 1 + a_compress_e_rate: 2 + a_compress_use_split: true + update_angle: true + update_style: "res_residual" + update_residual: 0.1 + update_residual_init: "const" + smooth_edge_update: true + use_dynamic_sel: true + sel_reduce_factor: 10.0 + activation_function: "custom_silu:10.0" + use_tebd_bias: false + precision: "float64" + concat_output_tebd: false + fitting_net: + neuron: + - 240 + - 240 + - 240 + resnet_dt: true + seed: 1 + precision: "float64" + activation_function: "custom_silu:10.0" + type: "ener" + numb_fparam: 0 + numb_aparam: 0 + dim_case_embd: 0 + trainable: true + rcond: null + atom_ener: [] + use_aparam_as_mask: 
false
+  training:
+    distributed: false
+    distributed_port: 39591
+    batch_size: 8
+    num_epochs: 100
+    learning_rate: 0.001
+    early_stopping_patience: 200
+    scheduler_patience: 100
+    scheduler_factor: 0.8
+    log_interval: 1
+    checkpoint_interval: 25
+    scale_targets: true
+    fixed_composition_weights: {}
+    per_structure_targets: []
+    log_mae: false
+    log_separate_blocks: false
+    best_model_metric: rmse_prod
+    loss:
+      type: mse
+      weights: {}
+      reduction: mean
\ No newline at end of file
diff --git a/src/metatrain/dpa3/model.py b/src/metatrain/dpa3/model.py
new file mode 100644
index 0000000000..450c22fd3a
--- /dev/null
+++ b/src/metatrain/dpa3/model.py
@@ -0,0 +1,455 @@
+from typing import Any, Dict, List, Literal, Optional
+import copy
+import metatensor.torch as mts
+import torch
+from metatensor.torch import Labels, TensorBlock, TensorMap
+from metatensor.torch.learn.nn import Linear as LinearMap
+from metatensor.torch.learn.nn import ModuleMap
+from metatomic.torch import (
+    AtomisticModel,
+    ModelCapabilities,
+    ModelMetadata,
+    ModelOutput,
+    NeighborListOptions,
+    System,
+)
+
+from metatrain.utils.abc import ModelInterface
+from metatrain.utils.additive import ZBL, OldCompositionModel
+from metatrain.utils.data import TargetInfo
+from metatrain.utils.data.dataset import DatasetInfo
+from metatrain.utils.dtype import dtype_to_str
+from metatrain.utils.long_range import DummyLongRangeFeaturizer, LongRangeFeaturizer
+from metatrain.utils.metadata import merge_metadata
+from metatrain.utils.scaler import Scaler
+from metatrain.utils.sum_over_atoms import sum_over_atoms
+
+from deepmd.pt.model.model import get_standard_model
+
+
+# Data processing
+def concatenate_structures(
+    systems: List[System]
+):
+    device = systems[0].positions.device
+    positions = []
+    species = []
+    cells = []
+    atom_nums: List[int] = []
+    node_counter = 0
+
+    atom_index_list: List[torch.Tensor] = []
+    system_index_list: List[torch.Tensor] = []
+
+    for i, system in enumerate(systems):
+        atom_nums.append(len(system.positions))
+        atom_index_list.append(torch.arange(start=0, end=len(system.positions)))
+        system_index_list.append(torch.full((len(system.positions),), i))
+    max_atom_num = max(atom_nums)
+    atom_index = torch.cat(atom_index_list, dim=0).to(torch.int32).to(device)
+    system_index = torch.cat(system_index_list, dim=0).to(torch.int32).to(device)
+
+    positions = torch.zeros(
+        (len(systems), max_atom_num, 3), dtype=systems[0].positions.dtype
+    )
+    species = torch.full((len(systems), max_atom_num), -1, dtype=systems[0].types.dtype)
+    cells = torch.stack(
+        [system.cell for system in systems]
+    )  # shape [batch_size, 3, 3], or the corresponding cell shape
+
+    for i, system in enumerate(systems):
+        positions[i, :len(system.positions)] = system.positions
+        species[i, :len(system.positions)] = system.types
+        cells[i] = system.cell
+        node_counter += len(system.positions)
+
+    return (
+        positions.to(device),
+        species.to(device),
+        cells.to(device),
+        atom_index,
+        system_index
+    )
+
+
+# Model definition
+class DPA3(ModelInterface):
+    __checkpoint_version__ = 1
+    __supported_devices__ = ["cuda", "cpu"]
+    __supported_dtypes__ = [torch.float64, torch.float32]
+    __default_metadata__ = ModelMetadata(
+        references={
+            "implementation": [
+                "https://github.com/deepmodeling/deepmd-kit",
+            ],
+            "architecture": [
+                "DPA3: https://arxiv.org/abs/2506.01686",
+            ],
+        }
+    )
+
+    component_labels: Dict[str, List[List[Labels]]]  # torchscript needs this
+
+    def __init__(self, hypers: Dict, dataset_info: DatasetInfo) -> None:
+        super().__init__(hypers, dataset_info)
+
self.atomic_types = dataset_info.atomic_types + self.model = get_standard_model(hypers) + + self.scaler = Scaler(hypers={}, dataset_info=dataset_info) + self.outputs = { + "features": ModelOutput(unit="", per_atom=True) + } + self.single_label = Labels.single() + + self.num_properties: Dict[str, Dict[str, int]] = {} # by target and block + self.basis_calculators = torch.nn.ModuleDict({}) + self.heads = torch.nn.ModuleDict({}) + self.last_layers = torch.nn.ModuleDict({}) + self.key_labels: Dict[str, Labels] = {} + self.component_labels: Dict[str, List[List[Labels]]] = {} + self.property_labels: Dict[str, List[Labels]] = {} + for target_name, target in dataset_info.targets.items(): + self._add_output(target_name, target) + + composition_model = OldCompositionModel( + hypers={}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if OldCompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + additive_models = [composition_model] + self.additive_models = torch.nn.ModuleList(additive_models) + + self.reverse_precision_dict ={ + torch.float16: "float16", + torch.float32: "float32", + torch.float64: "float64", + torch.int32: "int32", + torch.int64: "int64", + torch.bfloat16: "bfloat16", + torch.bool: "bool", + } + + def _input_type_cast( + self, + coord: torch.Tensor, + box: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + ) -> tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + str, + ]: + """Cast the input data to global float type.""" + input_prec = self.reverse_precision_dict[coord.dtype] + + _lst: list[Optional[torch.Tensor]] = [ + vv.to(coord.dtype) if vv is not None else None + for vv in [box, fparam, aparam] + ] + box, fparam, aparam = _lst + if ( + input_prec + == self.reverse_precision_dict[self.global_pt_float_precision] + ): + return coord, box, fparam, aparam, input_prec + else: + pp = torch.float32 + return ( + coord.to(pp), + box.to(pp) if box is not None else None, + fparam.to(pp) if fparam is not None else None, + aparam.to(pp) if aparam is not None else None, + input_prec, + ) + + + + def _add_output(self, target_name: str, target: TargetInfo) -> None: + self.num_properties[target_name] = {} + ll_features_name = ( + f"mtt::aux::{target_name.replace('mtt::', '')}_last_layer_features" + ) + self.outputs[ll_features_name] = ModelOutput(per_atom=True) + self.key_labels[target_name] = target.layout.keys + self.component_labels[target_name] = [ + block.components for block in target.layout.blocks() + ] + self.property_labels[target_name] = [ + block.properties for block in target.layout.blocks() + ] + + self.outputs[target_name] = ModelOutput( + quantity=target.quantity, + unit=target.unit, + per_atom=True, + ) + + def get_rcut(self): + return self.model.atomic_model.get_rcut() + + def get_sel(self): + return self.model.atomic_model.get_sel() + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + + device = systems[0].positions.device + system_indices = torch.concatenate( + [ + torch.full( + (len(system),), + i_system, + device=device, + ) + for i_system, system in enumerate(systems) + ], + ) + + return_dict: Dict[str, TensorMap] = {} + + sample_values = torch.stack( + [ + system_indices, + 
torch.concatenate( + [ + torch.arange( + len(system), + device=device, + ) + for system in systems + ], + ), + ], + dim=1, + ) + + ( + positions, + species, + cells, + atom_index, + system_index + ) = concatenate_structures(systems) + + positions = positions.to(torch.float64) + type_to_index = {atomic_type: idx for idx, atomic_type in enumerate(self.atomic_types)} + type_to_index[-1] = -1 + + atype = torch.tensor( + [[type_to_index[s.item()] for s in row] for row in species], + dtype=torch.int32 + ).to(positions.device) + atype = atype.to(torch.int32) + + if torch.all(cells == 0).item(): + box = None + else: + box = cells + + model_ret = self.model.forward_common( + positions, + atype, + box, + fparam=None, + aparam=None, + do_atomic_virial=False, + ) + + if self.model.get_fitting_net() is not None: + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.model.do_grad_r("energy"): + model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + + else: + model_predict["force"] = model_ret["dforce"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + else: + model_predict = model_ret + model_predict["updated_coord"] += positions + + atomic_properties: Dict[str, TensorMap] = {} + blocks: List[TensorBlock] = [] + + system_col = system_index + atom_col = atom_index + + values = torch.stack([system_col, atom_col], dim=0).transpose(0, 1) + invariant_coefficients = Labels( + names=["system", "atom"], + values=values.to(device) + ) + + + mask = torch.abs(model_predict["atom_energy"]) > 1e-10 + atomic_property_tensor = model_predict["atom_energy"][mask].unsqueeze(-1) + + blocks.append(TensorBlock( + values=atomic_property_tensor, + samples=invariant_coefficients, + components=self.component_labels["energy"][0], + properties=self.property_labels["energy"][0].to(device), + )) + + atomic_properties["energy"] = TensorMap(self.key_labels["energy"].to(device), blocks) + + + for output_name, atomic_property in atomic_properties.items(): + + if outputs[output_name].per_atom: + return_dict[output_name] = atomic_property + else: + # sum the atomic property to get the total property + return_dict[output_name] = sum_over_atoms(atomic_property) + + if not self.training: + # at evaluation, we also introduce the scaler and additive contributions + return_dict = self.scaler(return_dict) + for additive_model in self.additive_models: + outputs_for_additive_model: Dict[str, ModelOutput] = {} + for name, output in outputs.items(): + if name in additive_model.outputs: + outputs_for_additive_model[name] = output + additive_contributions = additive_model( + systems, + outputs_for_additive_model, + selected_atoms, + ) + for name in additive_contributions: + return_dict[name] = mts.add( + return_dict[name], + additive_contributions[name], + ) + + return return_dict + + + def restart(self, dataset_info: DatasetInfo) -> "DPA3": + # merge old and new dataset info + merged_info = self.dataset_info.union(dataset_info) + new_atomic_types = [ + at for at in merged_info.atomic_types if at not in self.atomic_types + ] + new_targets = { + key: value + for key, value in merged_info.targets.items() + if key not in self.dataset_info.targets + } + self.has_new_targets = len(new_targets) > 0 + + if len(new_atomic_types) > 0: + raise ValueError( + f"New atomic types found in the dataset: {new_atomic_types}. " + "The DPA3 model does not support adding new atomic types." 
+ ) + + # register new outputs as new last layers + for target_name, target in new_targets.items(): + self._add_output(target_name, target) + + self.dataset_info = merged_info + + # restart the composition and scaler models + self.additive_models[0].restart( + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if OldCompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + self.scaler.restart(dataset_info) + + return self + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + context: Literal["restart", "finetune", "export"], + ) -> "DPA3": + model_data = checkpoint["model_data"] + + if context == "restart": + model_state_dict = checkpoint["model_state_dict"] + elif context == "finetune" or context == "export": + model_state_dict = checkpoint["best_model_state_dict"] + if model_state_dict is None: + model_state_dict = checkpoint["model_state_dict"] + else: + raise ValueError("Unknown context tag for checkpoint loading!") + + # Create the model + model = cls( + hypers=model_data["model_hypers"], + dataset_info=model_data["dataset_info"], + ) + dtype = next(iter(model_state_dict.values())).dtype + model.to(dtype).load_state_dict(model_state_dict) + model.additive_models[0].sync_tensor_maps() + + # Loading the metadata from the checkpoint + metadata = checkpoint.get("metadata", None) + if metadata is not None: + model.__default_metadata__ = metadata + + return model + + def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: + dtype = next(self.parameters()).dtype + if dtype not in self.__supported_dtypes__: + raise ValueError(f"unsupported dtype {self.dtype} for DPA3") + + # Make sure the model is all in the same dtype + # For example, after training, the additive models could still be in + # float64 + self.to(dtype) + + # Additionally, the composition model contains some `TensorMap`s that cannot + # be registered correctly with Pytorch. 
This function moves them:
+
+        self.additive_models[0].weights_to(torch.device("cpu"), torch.float64)
+
+        interaction_ranges = [self.hypers['descriptor']['repflow']['e_rcut']]
+        for additive_model in self.additive_models:
+            if hasattr(additive_model, "cutoff_radius"):
+                interaction_ranges.append(additive_model.cutoff_radius)
+        interaction_range = max(interaction_ranges)
+
+        capabilities = ModelCapabilities(
+            outputs=self.outputs,
+            atomic_types=self.atomic_types,
+            interaction_range=interaction_range,
+            length_unit=self.dataset_info.length_unit,
+            supported_devices=self.__supported_devices__,
+            dtype=dtype_to_str(dtype),
+        )
+        if metadata is None:
+            metadata = self.__default_metadata__
+        else:
+            metadata = merge_metadata(self.__default_metadata__, metadata)
+
+        return AtomisticModel(self.eval(), metadata, capabilities)
+
+    @staticmethod
+    def upgrade_checkpoint(checkpoint: Dict) -> Dict:
+        raise NotImplementedError("checkpoint upgrade is not implemented for DPA3")
+
+    def supported_outputs(self) -> Dict[str, ModelOutput]:
+        return self.outputs
\ No newline at end of file
diff --git a/src/metatrain/dpa3/trainer.py b/src/metatrain/dpa3/trainer.py
new file mode 100644
index 0000000000..f6189b3839
--- /dev/null
+++ b/src/metatrain/dpa3/trainer.py
@@ -0,0 +1,556 @@
+import copy
+import logging
+from pathlib import Path
+from typing import Any, Dict, List, Literal, Union
+
+import torch
+import torch.distributed
+from torch.utils.data import DataLoader, DistributedSampler
+
+from metatrain.utils.abc import TrainerInterface
+from metatrain.utils.additive import remove_additive
+from metatrain.utils.data import (
+    CollateFn,
+    CombinedDataLoader,
+    Dataset,
+    _is_disk_dataset,
+)
+from metatrain.utils.distributed.distributed_data_parallel import (
+    DistributedDataParallel,
+)
+from metatrain.utils.distributed.slurm import DistributedEnvironment
+from metatrain.utils.evaluate_model import evaluate_model
+from metatrain.utils.external_naming import to_external_name
+from metatrain.utils.io import check_file_extension
+from metatrain.utils.logging import ROOT_LOGGER, MetricLogger
+from metatrain.utils.loss import TensorMapDictLoss
+from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric
+from metatrain.utils.neighbor_lists import (
+    get_requested_neighbor_lists,
+    get_system_with_neighbor_lists,
+)
+from metatrain.utils.per_atom import average_by_num_atoms
+from metatrain.utils.scaler import remove_scale
+from metatrain.utils.transfer import (
+    batch_to,
+)
+
+from .model import DPA3
+
+
+class Trainer(TrainerInterface):
+    __checkpoint_version__ = 1
+
+    def __init__(self, hypers):
+        super().__init__(hypers)
+
+        self.optimizer_state_dict = None
+        self.scheduler_state_dict = None
+        self.epoch = None
+        self.best_metric = None
+        self.best_model_state_dict = None
+        self.best_optimizer_state_dict = None
+
+    def train(
+        self,
+        model: DPA3,
+        dtype: torch.dtype,
+        devices: List[torch.device],
+        train_datasets: List[Union[Dataset, torch.utils.data.Subset]],
+        val_datasets: List[Union[Dataset, torch.utils.data.Subset]],
+        checkpoint_dir: str,
+    ):
+        assert dtype in DPA3.__supported_dtypes__
+
+        is_distributed = self.hypers["distributed"]
+
+        if is_distributed:
+            distr_env = DistributedEnvironment(self.hypers["distributed_port"])
+            torch.distributed.init_process_group(backend="nccl")
+            world_size = torch.distributed.get_world_size()
+            rank = torch.distributed.get_rank()
+        else:
+            rank = 0
+
+        if is_distributed:
+            if len(devices) > 1:
+                raise ValueError(
+                    "Requested distributed training with the `multi-gpu` device. "
+                    "If you want to run distributed training with DPA3, please "
+                    "set `device` to cuda."
+                )
+            # the calculation of the device number works both when GPUs on different
+            # processes are not visible to each other and when they are
+            device_number = distr_env.local_rank % torch.cuda.device_count()
+            device = torch.device("cuda", device_number)
+        else:
+            device = devices[0]  # only one device, as we don't support multi-gpu for now
+
+        if is_distributed:
+            logging.info(f"Training on {world_size} devices with dtype {dtype}")
+        else:
+            logging.info(f"Training on device {device} with dtype {dtype}")
+
+        # Calculate the neighbor lists in advance (in particular, this
+        # needs to happen before the additive models are trained, as they
+        # might need them):
+        logging.info("Calculating neighbor lists for the datasets")
+        requested_neighbor_lists = get_requested_neighbor_lists(model)
+        for dataset in train_datasets + val_datasets:
+            # If the dataset is a disk dataset, the NLs are already attached, we will
+            # just check the first system
+            if _is_disk_dataset(dataset):
+                system = dataset[0]["system"]
+                for options in requested_neighbor_lists:
+                    if options not in system.known_neighbor_lists():
+                        raise ValueError(
+                            "The requested neighbor lists are not attached to the "
+                            f"system. Neighbor list {options} is missing from the "
+                            "first system in the disk dataset. Make sure you save "
+                            "the neighbor lists in the systems when saving the dataset."
+                        )
+            else:
+                for sample in dataset:
+                    system = sample["system"]
+                    # The following line attaches the neighbors lists to the system,
+                    # and doesn't require to reassign the system to the dataset:
+                    get_system_with_neighbor_lists(system, requested_neighbor_lists)
+
+        # Move the model to the device and dtype:
+        model.to(device=device, dtype=dtype)
+        # The additive models of DPA3 are always in float64 (to avoid
+        # numerical errors in the composition weights, which can be very large).
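+        # (note: the cast below touches only these additive models; the main
+        # DPA3 network keeps the dtype selected above.)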
+ for additive_model in model.additive_models: + additive_model.to(dtype=torch.float64) + + logging.info("Calculating composition weights") + model.additive_models[0].train_model( # this is the composition model + train_datasets, + model.additive_models[1:], + self.hypers["fixed_composition_weights"], + ) + + if self.hypers["scale_targets"]: + logging.info("Calculating scaling weights") + model.scaler.train_model( + train_datasets, model.additive_models, treat_as_additive=True + ) + + if is_distributed: + model = DistributedDataParallel(model, device_ids=[device]) + + logging.info("Setting up data loaders") + + if is_distributed: + train_samplers = [ + DistributedSampler( + train_dataset, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + ) + for train_dataset in train_datasets + ] + val_samplers = [ + DistributedSampler( + val_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + for val_dataset in val_datasets + ] + else: + train_samplers = [None] * len(train_datasets) + val_samplers = [None] * len(val_datasets) + + # Create a collate function: + targets_keys = list( + (model.module if is_distributed else model).dataset_info.targets.keys() + ) + collate_fn = CollateFn(target_keys=targets_keys) + + # Create dataloader for the training datasets: + train_dataloaders = [] + for train_dataset, train_sampler in zip(train_datasets, train_samplers): + if len(train_dataset) < self.hypers["batch_size"]: + raise ValueError( + f"A training dataset has fewer samples " + f"({len(train_dataset)}) than the batch size " + f"({self.hypers['batch_size']}). " + "Please reduce the batch size." + ) + train_dataloaders.append( + DataLoader( + dataset=train_dataset, + batch_size=self.hypers["batch_size"], + sampler=train_sampler, + shuffle=( + # the sampler takes care of this (if present) + train_sampler is None + ), + drop_last=( + # the sampler takes care of this (if present) + train_sampler is None + ), + collate_fn=collate_fn, + ) + ) + train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=False) + + # Create dataloader for the validation datasets: + val_dataloaders = [] + for val_dataset, val_sampler in zip(val_datasets, val_samplers): + if len(val_dataset) < self.hypers["batch_size"]: + raise ValueError( + f"A validation dataset has fewer samples " + f"({len(val_dataset)}) than the batch size " + f"({self.hypers['batch_size']}). " + "Please reduce the batch size." 
+ ) + val_dataloaders.append( + DataLoader( + dataset=val_dataset, + batch_size=self.hypers["batch_size"], + sampler=val_sampler, + shuffle=False, + drop_last=False, + collate_fn=collate_fn, + ) + ) + val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False) + + # Extract all the possible outputs and their gradients: + train_targets = (model.module if is_distributed else model).dataset_info.targets + outputs_list = [] + for target_name, target_info in train_targets.items(): + outputs_list.append(target_name) + for gradient_name in target_info.gradients: + outputs_list.append(f"{target_name}_{gradient_name}_gradients") + # Create a loss weight dict: + loss_weights_dict = {} + for output_name in outputs_list: + loss_weights_dict[output_name] = ( + self.hypers["loss"]["weights"][ + to_external_name(output_name, train_targets) + ] + if to_external_name(output_name, train_targets) + in self.hypers["loss"]["weights"] + else 1.0 + ) + loss_weights_dict_external = { + to_external_name(key, train_targets): value + for key, value in loss_weights_dict.items() + } + loss_hypers = copy.deepcopy(self.hypers["loss"]) + loss_hypers["weights"] = loss_weights_dict + logging.info(f"Training with loss weights: {loss_weights_dict_external}") + + # Create a loss function: + loss_fn = TensorMapDictLoss( + **loss_hypers, + ) + + # Create an optimizer: + optimizer = torch.optim.Adam( + model.parameters(), lr=self.hypers["learning_rate"] + ) + if self.optimizer_state_dict is not None: + # try to load the optimizer state dict, but this is only possible + # if there are no new targets in the model (new parameters) + if not (model.module if is_distributed else model).has_new_targets: + optimizer.load_state_dict(self.optimizer_state_dict) + + # Create a scheduler: + lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + factor=self.hypers["scheduler_factor"], + patience=self.hypers["scheduler_patience"], + threshold=0.001, + min_lr=1e-5, + ) + if self.scheduler_state_dict is not None: + # same as the optimizer, try to load the scheduler state dict + if not (model.module if is_distributed else model).has_new_targets: + lr_scheduler.load_state_dict(self.scheduler_state_dict) + + # per-atom targets: + per_structure_targets = self.hypers["per_structure_targets"] + + # Log the initial learning rate: + old_lr = optimizer.param_groups[0]["lr"] + logging.info(f"Initial learning rate: {old_lr}") + + start_epoch = 0 if self.epoch is None else self.epoch + 1 + + # Train the model: + if self.best_metric is None: + self.best_metric = float("inf") + logging.info("Starting training") + epoch = start_epoch + for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]): + if is_distributed: + for train_sampler in train_samplers: + train_sampler.set_epoch(epoch) + + train_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + val_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + if self.hypers["log_mae"]: + train_mae_calculator = MAEAccumulator( + self.hypers["log_separate_blocks"] + ) + val_mae_calculator = MAEAccumulator(self.hypers["log_separate_blocks"]) + + train_loss = 0.0 + + for batch in train_dataloader: + optimizer.zero_grad() + + systems, targets, extra_data = batch + systems, targets, extra_data = batch_to( + systems, targets, extra_data, device=device + ) + for additive_model in ( + model.module if is_distributed else model + ).additive_models: + targets = remove_additive( + systems, targets, additive_model, train_targets + ) + targets = 
remove_scale( + targets, (model.module if is_distributed else model).scaler + ) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype + ) + + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=True, + ) + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + + train_loss_batch = loss_fn(predictions, targets) + + train_loss_batch.backward() + optimizer.step() + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(train_loss_batch) + train_loss += train_loss_batch.item() + train_rmse_calculator.update(predictions, targets) + if self.hypers["log_mae"]: + train_mae_calculator.update(predictions, targets) + + finalized_train_info = train_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_train_info.update( + train_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + val_loss = 0.0 + for batch in val_dataloader: + systems, targets, extra_data = batch + systems, targets, extra_data = batch_to( + systems, targets, extra_data, device=device + ) + for additive_model in ( + model.module if is_distributed else model + ).additive_models: + targets = remove_additive( + systems, targets, additive_model, train_targets + ) + targets = remove_scale( + targets, (model.module if is_distributed else model).scaler + ) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype + ) + + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=False, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + + val_loss_batch = loss_fn(predictions, targets) + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(val_loss_batch) + val_loss += val_loss_batch.item() + val_rmse_calculator.update(predictions, targets) + if self.hypers["log_mae"]: + val_mae_calculator.update(predictions, targets) + + finalized_val_info = val_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_val_info.update( + val_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + # Now we log the information: + finalized_train_info = {"loss": train_loss, **finalized_train_info} + finalized_val_info = {"loss": val_loss, **finalized_val_info} + + if epoch == start_epoch: + scaler_scales = ( + model.module if is_distributed else model + ).scaler.get_scales_dict() + metric_logger = MetricLogger( + log_obj=ROOT_LOGGER, + dataset_info=( + model.module if is_distributed else model + ).dataset_info, + initial_metrics=[finalized_train_info, finalized_val_info], + names=["training", "validation"], + scales={ + key: ( + scaler_scales[key.split(" ")[0]] + if ("MAE" in key or "RMSE" in key) + else 1.0 + ) + for key in finalized_train_info.keys() + }, + ) + if epoch % 
self.hypers["log_interval"] == 0: + metric_logger.log( + metrics=[finalized_train_info, finalized_val_info], + epoch=epoch, + rank=rank, + ) + + lr_scheduler.step(val_loss) + new_lr = lr_scheduler.get_last_lr()[0] + if new_lr != old_lr: + if new_lr < 1e-7: + logging.info("Learning rate is too small, stopping training") + break + else: + logging.info(f"Changing learning rate from {old_lr} to {new_lr}") + old_lr = new_lr + # load best model and optimizer state dict, re-initialize scheduler + (model.module if is_distributed else model).load_state_dict( + self.best_model_state_dict + ) + optimizer.load_state_dict(self.best_optimizer_state_dict) + for param_group in optimizer.param_groups: + param_group["lr"] = new_lr + lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + factor=self.hypers["scheduler_factor"], + patience=self.hypers["scheduler_patience"], + ) + + val_metric = get_selected_metric( + finalized_val_info, self.hypers["best_model_metric"] + ) + if val_metric < self.best_metric: + self.best_metric = val_metric + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + if epoch % self.hypers["checkpoint_interval"] == 0: + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + self.epoch = epoch + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_{epoch}.ckpt", + ) + + # prepare for the checkpoint that will be saved outside the function + self.epoch = epoch + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + + def save_checkpoint(self, model, path: Union[str, Path]): + checkpoint = { + "architecture_name": "soap_bpnn", + "model_ckpt_version": model.__checkpoint_version__, + "trainer_ckpt_version": self.__checkpoint_version__, + "metadata": model.__default_metadata__, + "model_data": { + "model_hypers": model.hypers, + "dataset_info": model.dataset_info, + }, + "model_state_dict": model.state_dict(), + "train_hypers": self.hypers, + "epoch": self.epoch, + "optimizer_state_dict": self.optimizer_state_dict, + "scheduler_state_dict": self.scheduler_state_dict, + "best_metric": self.best_metric, + "best_model_state_dict": self.best_model_state_dict, + "best_optimizer_state_dict": self.best_optimizer_state_dict, + } + torch.save( + checkpoint, + check_file_extension(path, ".ckpt"), + ) + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + hypers: Dict[str, Any], + context: Literal["restart", "finetune"], # not used at the moment + ) -> "Trainer": + epoch = checkpoint["epoch"] + optimizer_state_dict = checkpoint["optimizer_state_dict"] + scheduler_state_dict = checkpoint["scheduler_state_dict"] + best_metric = checkpoint["best_metric"] + best_model_state_dict = checkpoint["best_model_state_dict"] + best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"] + + # Create the trainer + trainer = cls(hypers) + trainer.optimizer_state_dict = optimizer_state_dict + trainer.scheduler_state_dict = scheduler_state_dict + trainer.epoch = epoch + trainer.best_metric = best_metric + trainer.best_model_state_dict = best_model_state_dict + trainer.best_optimizer_state_dict = best_optimizer_state_dict + + return trainer + + @staticmethod + def upgrade_checkpoint(checkpoint: Dict) -> Dict: + raise 
NotImplementedError("checkpoint upgrade is not implemented for SoapBPNN") From 13b85003ee6842571903792a84807d885561f0c3 Mon Sep 17 00:00:00 2001 From: liwentao Date: Wed, 13 Aug 2025 07:09:50 +0000 Subject: [PATCH 02/18] fix dpa3 interface & add dependency --- pyproject.toml | 6 +++--- src/metatrain/dpa3/model.py | 18 +++++++++++++++++- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 49d1d4aa01..0f071291ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,9 @@ requires = [ build-backend = "setuptools.build_meta" [project.optional-dependencies] +dpa3 = [ + "deepmd-kit>=3.1.0" +] soap-bpnn = [ "torch-spex>=0.1,<0.2", "wigners", @@ -79,9 +82,6 @@ gap = [ "skmatter", "scipy", ] -dpa3 = [ - "deepmd-kit>=3.1.0" -] [tool.check-manifest] ignore = ["src/metatrain/_version.py"] diff --git a/src/metatrain/dpa3/model.py b/src/metatrain/dpa3/model.py index 450c22fd3a..3689a661c2 100644 --- a/src/metatrain/dpa3/model.py +++ b/src/metatrain/dpa3/model.py @@ -88,7 +88,7 @@ class DPA3(ModelInterface): component_labels: Dict[str, List[List[Labels]]] # torchscript needs this def __init__(self, hypers: Dict, dataset_info: DatasetInfo) -> None: - super().__init__(hypers, dataset_info) + super().__init__(hypers, dataset_info, self.__default_metadata__) self.atomic_types = dataset_info.atomic_types self.model = get_standard_model(hypers) @@ -451,5 +451,21 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: def upgrade_checkpoint(checkpoint: Dict) -> Dict: raise NotImplementedError("checkpoint upgrade is not implemented for DPA3") + def get_checkpoint(self) -> Dict: + model_state_dict = self.state_dict() + model_state_dict["finetune_config"] = self.finetune_config + checkpoint = { + "architecture_name": "dpa3", + "model_ckpt_version": self.__checkpoint_version__, + "metadata": self.metadata, + "model_data": { + "model_hypers": self.hypers, + "dataset_info": self.dataset_info, + }, + "model_state_dict": model_state_dict, + "best_model_state_dict": None, + } + return checkpoint + def supported_outputs(self) -> Dict[str, ModelOutput]: return self.outputs \ No newline at end of file From 385170b0d645bc8b81afd22f94958b0899cfecac Mon Sep 17 00:00:00 2001 From: liwentao Date: Fri, 15 Aug 2025 11:30:13 +0000 Subject: [PATCH 03/18] add unit tests --- docs/static/qm9/options.yaml | 4 +- examples/basic_usage/run_dpa3.sh | 2 +- .../{ => experimental}/dpa3/__init__.py | 0 .../dpa3/default-hypers.yaml | 0 .../{ => experimental}/dpa3/model.py | 86 ++++++++--- .../experimental/dpa3/tests/__init__.py | 12 ++ .../dpa3/tests/test_checkpoints.py | 134 +++++++++++++++++ .../experimental/dpa3/tests/test_continue.py | 106 +++++++++++++ .../dpa3/tests/test_regression.py | 141 ++++++++++++++++++ .../dpa3/tests/test_torchscript.py | 91 +++++++++++ .../{ => experimental}/dpa3/trainer.py | 16 +- 11 files changed, 567 insertions(+), 25 deletions(-) rename src/metatrain/{ => experimental}/dpa3/__init__.py (100%) rename src/metatrain/{ => experimental}/dpa3/default-hypers.yaml (100%) rename src/metatrain/{ => experimental}/dpa3/model.py (84%) create mode 100644 src/metatrain/experimental/dpa3/tests/__init__.py create mode 100644 src/metatrain/experimental/dpa3/tests/test_checkpoints.py create mode 100644 src/metatrain/experimental/dpa3/tests/test_continue.py create mode 100644 src/metatrain/experimental/dpa3/tests/test_regression.py create mode 100644 src/metatrain/experimental/dpa3/tests/test_torchscript.py rename src/metatrain/{ => 
experimental}/dpa3/trainer.py (97%)

diff --git a/docs/static/qm9/options.yaml b/docs/static/qm9/options.yaml
index b72b041c1f..b32a40d94f 100644
--- a/docs/static/qm9/options.yaml
+++ b/docs/static/qm9/options.yaml
@@ -1,8 +1,8 @@
 # architecture used to train the model
 architecture:
-  name: dpa3
+  name: experimental.dpa3
   training:
-    num_epochs: 200 # a very short training run
+    num_epochs: 2 # a very short training run
     batch_size: 10
 
 # Mandatory section defining the parameters for system and target data of the
diff --git a/examples/basic_usage/run_dpa3.sh b/examples/basic_usage/run_dpa3.sh
index 325168d77f..98e3a8ea8b 100644
--- a/examples/basic_usage/run_dpa3.sh
+++ b/examples/basic_usage/run_dpa3.sh
@@ -1,6 +1,6 @@
 export METATENSOR_DEBUG_EXTENSIONS_LOADING=1
 
-mtt train options.yaml
+# mtt train options.yaml
 
 package_dir=$(python -c "import site; print(site.getsitepackages()[0])")
 cp $package_dir/deepmd/lib/*.so extensions/deepmd/lib/
diff --git a/src/metatrain/dpa3/__init__.py b/src/metatrain/experimental/dpa3/__init__.py
similarity index 100%
rename from src/metatrain/dpa3/__init__.py
rename to src/metatrain/experimental/dpa3/__init__.py
diff --git a/src/metatrain/dpa3/default-hypers.yaml b/src/metatrain/experimental/dpa3/default-hypers.yaml
similarity index 100%
rename from src/metatrain/dpa3/default-hypers.yaml
rename to src/metatrain/experimental/dpa3/default-hypers.yaml
diff --git a/src/metatrain/dpa3/model.py b/src/metatrain/experimental/dpa3/model.py
similarity index 84%
rename from src/metatrain/dpa3/model.py
rename to src/metatrain/experimental/dpa3/model.py
index 3689a661c2..4c760cbd6f 100644
--- a/src/metatrain/dpa3/model.py
+++ b/src/metatrain/experimental/dpa3/model.py
@@ -26,6 +26,20 @@
 from deepmd.pt.model.model import get_standard_model
 
 
+def update_v1_v2(state_dict):
+    # This if-statement is necessary to handle cases when
+    # best_model_state_dict and model_state_dict are the same.
+    # In that case, both are updated within the first call of
+    # this function in the DPA3.upgrade_checkpoint() method.
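+    # The v1 -> v2 upgrade below re-keys the composition model's
+    # `type_to_index` buffer under the nested "model." prefix.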
+ if ( + state_dict is not None + and "additive_models.0.model.type_to_index" not in state_dict + ): + state_dict["additive_models.0.model.type_to_index"] = state_dict.pop( + "additive_models.0.type_to_index" + ) + + # Data processing def concatenate_structures( systems: List[System] @@ -90,6 +104,14 @@ class DPA3(ModelInterface): def __init__(self, hypers: Dict, dataset_info: DatasetInfo) -> None: super().__init__(hypers, dataset_info, self.__default_metadata__) self.atomic_types = dataset_info.atomic_types + + self.requested_nl = NeighborListOptions( + cutoff=self.hypers["descriptor"]["repflow"]["e_rcut"], + full_list=True, + strict=True, + ) + self.targets_keys = list(dataset_info.targets.keys())[0] + self.model = get_standard_model(hypers) self.scaler = Scaler(hypers={}, dataset_info=dataset_info) @@ -98,10 +120,8 @@ def __init__(self, hypers: Dict, dataset_info: DatasetInfo) -> None: } self.single_label = Labels.single() - self.num_properties: Dict[str, Dict[str, int]] = {} # by target and block - self.basis_calculators = torch.nn.ModuleDict({}) - self.heads = torch.nn.ModuleDict({}) - self.last_layers = torch.nn.ModuleDict({}) + self.num_properties: Dict[str, Dict[str, int]] = {} + self.key_labels: Dict[str, Labels] = {} self.component_labels: Dict[str, List[List[Labels]]] = {} self.property_labels: Dict[str, List[Labels]] = {} @@ -196,6 +216,9 @@ def get_rcut(self): def get_sel(self): return self.model.atomic_model.get_sel() + + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return [self.requested_nl] def forward( self, @@ -205,6 +228,8 @@ def forward( ) -> Dict[str, TensorMap]: device = systems[0].positions.device + position_dtype = systems[0].positions.dtype + atype_dtype = systems[0].types.dtype system_indices = torch.concatenate( [ torch.full( @@ -216,6 +241,24 @@ def forward( ], ) + if self.single_label.values.device != device: + self.single_label = self.single_label.to(device) + self.key_labels = { + output_name: label.to(device) + for output_name, label in self.key_labels.items() + } + self.component_labels = { + output_name: [ + [labels.to(device) for labels in components_block] + for components_block in components_tmap + ] + for output_name, components_tmap in self.component_labels.items() + } + self.property_labels = { + output_name: [labels.to(device) for labels in properties_tmap] + for output_name, properties_tmap in self.property_labels.items() + } + return_dict: Dict[str, TensorMap] = {} sample_values = torch.stack( @@ -242,15 +285,15 @@ def forward( system_index ) = concatenate_structures(systems) - positions = positions.to(torch.float64) + positions = positions.to(position_dtype) type_to_index = {atomic_type: idx for idx, atomic_type in enumerate(self.atomic_types)} type_to_index[-1] = -1 atype = torch.tensor( [[type_to_index[s.item()] for s in row] for row in species], - dtype=torch.int32 + dtype=atype_dtype ).to(positions.device) - atype = atype.to(torch.int32) + atype = atype.to(atype_dtype) if torch.all(cells == 0).item(): box = None @@ -300,11 +343,11 @@ def forward( blocks.append(TensorBlock( values=atomic_property_tensor, samples=invariant_coefficients, - components=self.component_labels["energy"][0], - properties=self.property_labels["energy"][0].to(device), + components=self.component_labels[self.targets_keys][0], + properties=self.property_labels[self.targets_keys][0].to(device), )) - atomic_properties["energy"] = TensorMap(self.key_labels["energy"].to(device), blocks) + atomic_properties[self.targets_keys] = 
TensorMap(self.key_labels[self.targets_keys].to(device), blocks) for output_name, atomic_property in atomic_properties.items(): @@ -333,7 +376,7 @@ def forward( return_dict[name], additive_contributions[name], ) - + return return_dict @@ -447,13 +490,22 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: return AtomisticModel(self.eval(), metadata, capabilities) - @staticmethod - def upgrade_checkpoint(checkpoint: Dict) -> Dict: - raise NotImplementedError("checkpoint upgrade is not implemented for DPA3") + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + if checkpoint["model_ckpt_version"] == 1: + checkpoints.update_v1_v2(checkpoint["model_state_dict"]) + checkpoints.update_v1_v2(checkpoint["best_model_state_dict"]) + checkpoint["model_ckpt_version"] = 2 + + if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using model " + f"version {checkpoint['model_ckpt_version']}, while the current " + f"model version is {cls.__checkpoint_version__}." + ) + return checkpoint def get_checkpoint(self) -> Dict: - model_state_dict = self.state_dict() - model_state_dict["finetune_config"] = self.finetune_config checkpoint = { "architecture_name": "dpa3", "model_ckpt_version": self.__checkpoint_version__, @@ -462,7 +514,7 @@ def get_checkpoint(self) -> Dict: "model_hypers": self.hypers, "dataset_info": self.dataset_info, }, - "model_state_dict": model_state_dict, + "model_state_dict": self.state_dict(), "best_model_state_dict": None, } return checkpoint diff --git a/src/metatrain/experimental/dpa3/tests/__init__.py b/src/metatrain/experimental/dpa3/tests/__init__.py new file mode 100644 index 0000000000..b997a00c2e --- /dev/null +++ b/src/metatrain/experimental/dpa3/tests/__init__.py @@ -0,0 +1,12 @@ +from pathlib import Path + +from metatrain.utils.architectures import get_default_hypers + + +DATASET_PATH = str(Path(__file__).parents[4] / "tests/resources/qm9_reduced_100.xyz") +DATASET_WITH_FORCES_PATH = str( + Path(__file__).parents[4] / "tests/resources/carbon_reduced_100.xyz" +) + +DEFAULT_HYPERS = get_default_hypers("dpa3") +MODEL_HYPERS = DEFAULT_HYPERS["model"] diff --git a/src/metatrain/experimental/dpa3/tests/test_checkpoints.py b/src/metatrain/experimental/dpa3/tests/test_checkpoints.py new file mode 100644 index 0000000000..de97c89e58 --- /dev/null +++ b/src/metatrain/experimental/dpa3/tests/test_checkpoints.py @@ -0,0 +1,134 @@ +import copy + +import pytest +import torch + +from metatrain.dpa3 import DPA3, Trainer +from metatrain.utils.data import ( + DatasetInfo, + get_atomic_types, + get_dataset, +) +from metatrain.utils.data.target_info import get_energy_target_info +from metatrain.utils.testing.checkpoints import ( + checkpoint_did_not_change, + make_checkpoint_load_tests, +) + +from . 
import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS + +# from pathlib import Path + +# from metatrain.utils.architectures import get_default_hypers + +# DATASET_PATH = str(Path(__file__).parents[4] / "tests/resources/qm9_reduced_100.xyz") +# DATASET_WITH_FORCES_PATH = str( +# Path(__file__).parents[4] / "tests/resources/carbon_reduced_100.xyz" +# ) + +# DEFAULT_HYPERS = get_default_hypers("dpa3") +# MODEL_HYPERS = DEFAULT_HYPERS["model"] + + +@pytest.fixture(scope="module") +def model_trainer(): + energy_target = { + "quantity": "energy", + "read_from": DATASET_PATH, + "reader": "ase", + "key": "U0", + "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, + "forces": False, + "stress": False, + "virial": False, + } + + dataset, targets_info, _ = get_dataset( + { + "systems": { + "read_from": DATASET_PATH, + "reader": "ase", + }, + "targets": { + "energy": energy_target, + }, + } + ) + + dataset_info = DatasetInfo( + length_unit="", + atomic_types=get_atomic_types(dataset), + targets=targets_info, + ) + + # minimize the size of the checkpoint on disk + hypers = copy.deepcopy(MODEL_HYPERS) + + + model = DPA3(hypers, dataset_info) + + hypers = copy.deepcopy(DEFAULT_HYPERS) + hypers["training"]["num_epochs"] = 1 + trainer = Trainer(hypers["training"]) + + trainer.train( + model, + dtype=model.__supported_dtypes__[0], + devices=[torch.device("cpu")], + train_datasets=[dataset], + val_datasets=[dataset], + checkpoint_dir="", + ) + + return model, trainer + + +test_checkpoint_did_not_change = checkpoint_did_not_change + +test_loading_old_checkpoints = make_checkpoint_load_tests(DEFAULT_HYPERS) + + +@pytest.mark.parametrize("context", ["finetune", "restart", "export"]) +def test_get_checkpoint(context): + """ + Test that the checkpoint created by the model.get_checkpoint() + function can be loaded back in all possible contexts. + """ + dataset_info = DatasetInfo( + length_unit="Angstrom", + atomic_types=[1, 6, 7, 8], + targets={"energy": get_energy_target_info({"unit": "eV"})}, + ) + model = DPA3(MODEL_HYPERS, dataset_info) + checkpoint = model.get_checkpoint() + DPA3.load_checkpoint(checkpoint, context) + + +@pytest.mark.parametrize("cls_type", ["model", "trainer"]) +def test_failed_checkpoint_upgrade(cls_type): + """Test error raised when trying to upgrade an invalid checkpoint version.""" + checkpoint = {f"{cls_type}_ckpt_version": 9999} + + if cls_type == "model": + cls = DPA3 + version = DPA3.__checkpoint_version__ + else: + cls = Trainer + version = Trainer.__checkpoint_version__ + + match = ( + f"Unable to upgrade the checkpoint: the checkpoint is using {cls_type} version " + f"9999, while the current {cls_type} version is {version}." 
+ ) + with pytest.raises(RuntimeError, match=match): + cls.upgrade_checkpoint(checkpoint) + +if __name__ == "__main__": + test_get_checkpoint("finetune") + test_get_checkpoint("restart") + test_get_checkpoint("export") + test_failed_checkpoint_upgrade("model") + test_failed_checkpoint_upgrade("trainer") \ No newline at end of file diff --git a/src/metatrain/experimental/dpa3/tests/test_continue.py b/src/metatrain/experimental/dpa3/tests/test_continue.py new file mode 100644 index 0000000000..8bdb951da6 --- /dev/null +++ b/src/metatrain/experimental/dpa3/tests/test_continue.py @@ -0,0 +1,106 @@ +import shutil + +import metatensor +import torch +from omegaconf import OmegaConf + +from metatrain.dpa3 import DPA3, Trainer +from metatrain.utils.data import Dataset, DatasetInfo +from metatrain.utils.data.readers import read_systems, read_targets +from metatrain.utils.data.target_info import get_energy_target_info +from metatrain.utils.io import model_from_checkpoint +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) + +from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS + +def test_continue(monkeypatch, tmp_path): + """Tests that a model can be checkpointed and loaded + for a continuation of the training process""" + + monkeypatch.chdir(tmp_path) + shutil.copy(DATASET_PATH, "qm9_reduced_100.xyz") + + systems = read_systems(DATASET_PATH) + systems = [system.to(torch.float32) for system in systems] + + target_info_dict = {} + target_info_dict["mtt::U0"] = get_energy_target_info({"unit": "eV"}) + + dataset_info = DatasetInfo( + length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict + ) + model = DPA3(MODEL_HYPERS, dataset_info).to(systems[0].positions.device) + requested_neighbor_lists = get_requested_neighbor_lists(model) + systems = [ + get_system_with_neighbor_lists(system, requested_neighbor_lists) + for system in systems + ] + output_before = model(systems[:5], {"mtt::U0": model.outputs["mtt::U0"]}) + + conf = { + "mtt::U0": { + "quantity": "energy", + "read_from": DATASET_PATH, + "reader": "ase", + "key": "U0", + "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, + "forces": False, + "stress": False, + "virial": False, + } + } + targets, _ = read_targets(OmegaConf.create(conf)) + + # systems in float64 are required for training + systems = [system.to(torch.float64) for system in systems] + dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) + + hypers = DEFAULT_HYPERS.copy() + hypers["training"]["num_epochs"] = 0 + trainer = Trainer(hypers["training"]) + trainer.train( + model=model, + dtype=torch.float64, + devices=[torch.device("cpu")], + train_datasets=[dataset], + val_datasets=[dataset], + checkpoint_dir=".", + ) + trainer.save_checkpoint(model, "temp.ckpt") + checkpoint = torch.load("temp.ckpt", weights_only=False, map_location="cpu") + model_after = model_from_checkpoint(checkpoint, context="restart") + assert isinstance(model_after, DPA3) + model_after.restart(dataset_info) + + hypers["training"]["num_epochs"] = 0 + trainer = Trainer(hypers["training"]) + trainer.train( + model=model_after, + dtype=torch.float64, + devices=[torch.device("cpu")], + train_datasets=[dataset], + val_datasets=[dataset], + checkpoint_dir=".", + ) + + # evaluation + systems = [system.to(torch.float32) for system in systems] + + model.eval() + model_after.eval() + + # Predict on the first five systems + output_before = model(systems[:5], {"mtt::U0": 
model.outputs["mtt::U0"]}) + output_after = model_after(systems[:5], {"mtt::U0": model_after.outputs["mtt::U0"]}) + + assert metatensor.torch.allclose(output_before["mtt::U0"], output_after["mtt::U0"]) + +if __name__ == "__main__": + tmp_path = "/aisi/mnt/data_nas/liwentao/devel_workspace/metatrain/metatrain/src/metatrain/dpa3/tests/tmp" + test_continue(monkeypatch, tmp_path) \ No newline at end of file diff --git a/src/metatrain/experimental/dpa3/tests/test_regression.py b/src/metatrain/experimental/dpa3/tests/test_regression.py new file mode 100644 index 0000000000..a76d34bdf8 --- /dev/null +++ b/src/metatrain/experimental/dpa3/tests/test_regression.py @@ -0,0 +1,141 @@ +import random + +import numpy as np +import torch +from metatomic.torch import ModelOutput +from omegaconf import OmegaConf + +from metatrain.dpa3 import DPA3, Trainer +from metatrain.utils.data import Dataset, DatasetInfo +from metatrain.utils.data.readers import ( + read_systems, + read_targets, +) +from metatrain.utils.data.target_info import get_energy_target_info +from metatrain.utils.evaluate_model import evaluate_model +from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists + +from . import DATASET_PATH, DATASET_WITH_FORCES_PATH, DEFAULT_HYPERS, MODEL_HYPERS + + +def test_regression_init(): + """Regression test for the model at initialization""" + random.seed(0) + np.random.seed(0) + torch.manual_seed(0) + + targets = {} + targets["mtt::U0"] = get_energy_target_info({"quantity": "energy", "unit": "eV"}) + + dataset_info = DatasetInfo( + length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets + ) + model = DPA3(MODEL_HYPERS, dataset_info).to('cpu') + + # Predict on the first five systems + systems = read_systems(DATASET_PATH)[:5] + systems = [system.to(torch.float64) for system in systems] + for system in systems: + get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + + output = model( + systems, + {"mtt::U0": ModelOutput(quantity="energy", unit="", per_atom=False)}, + ) + + expected_output = torch.tensor( + [ + [-8.196924407132], + [-6.523280256844], + [-4.913442698461], + [-6.568228343430], + [-4.895156818840], + ],dtype=torch.float64 + ) + + # if you need to change the hardcoded values: + # torch.set_printoptions(precision=12) + # print(output["mtt::U0"].block().values) + + torch.testing.assert_close(output["mtt::U0"].block().values, expected_output) + + +def test_regression_energies_forces_train(): + """Regression test for the model when trained for 2 epoch on a small dataset""" + random.seed(0) + np.random.seed(0) + torch.manual_seed(0) + + systems = read_systems(DATASET_WITH_FORCES_PATH) + + conf = { + "energy": { + "quantity": "energy", + "read_from": DATASET_WITH_FORCES_PATH, + "reader": "ase", + "key": "energy", + "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, + "forces": {"read_from": DATASET_WITH_FORCES_PATH, "key": "force"}, + "stress": False, + "virial": False, + } + } + + targets, target_info_dict = read_targets(OmegaConf.create(conf)) + targets = {"energy": targets["energy"]} + dataset = Dataset.from_dict({"system": systems, "energy": targets["energy"]}) + hypers = DEFAULT_HYPERS.copy() + hypers["training"]["num_epochs"] = 1 + hypers["training"]["scheduler_patience"] = 1 + hypers["training"]["fixed_composition_weights"] = {} + + dataset_info = DatasetInfo( + length_unit="Angstrom", atomic_types=[6], targets=target_info_dict + ) + model = DPA3(MODEL_HYPERS, dataset_info).to('cpu') + trainer = Trainer(hypers["training"]) 
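+    # Single-epoch smoke training; the tensors hardcoded below are the
+    # regression reference values for this run.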
+ trainer.train( + model=model, + dtype=torch.float64, + devices=[torch.device("cpu")], + train_datasets=[dataset], + val_datasets=[dataset], + checkpoint_dir=".", + ) + + # Predict on the first five systems + systems = [system.to(torch.float64) for system in systems] + for system in systems: + get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + + output = evaluate_model( + model, systems[:5], targets=target_info_dict, is_training=False + ) + + expected_output = torch.tensor( + [ + [-0.555318677882], + [-0.569078342763], + [-0.579769296313], + [-0.518369165620], + [-0.556493731493] + ],dtype=torch.float64 + ) + + expected_gradients_output = torch.tensor( + [-0.006725381569, 0.008463345547, 0.025475740380],dtype=torch.float64 + ) + + # if you need to change the hardcoded values: + # torch.set_printoptions(precision=12) + # print(output["energy"].block().values) + # print(output["energy"].block().gradient("positions").values.squeeze(-1)[0]) + + torch.testing.assert_close(output["energy"].block().values, expected_output) + torch.testing.assert_close( + output["energy"].block().gradient("positions").values[0, :, 0], + expected_gradients_output, + ) diff --git a/src/metatrain/experimental/dpa3/tests/test_torchscript.py b/src/metatrain/experimental/dpa3/tests/test_torchscript.py new file mode 100644 index 0000000000..725debc45a --- /dev/null +++ b/src/metatrain/experimental/dpa3/tests/test_torchscript.py @@ -0,0 +1,91 @@ +import copy + +import torch +from metatomic.torch import System + +from metatrain.dpa3 import DPA3 +from metatrain.utils.data import DatasetInfo +from metatrain.utils.data.target_info import get_energy_target_info +from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists + +from . import MODEL_HYPERS + + +def test_torchscript(): + """Tests that the model can be jitted.""" + + dataset_info = DatasetInfo( + length_unit="Angstrom", + atomic_types=[1, 6, 7, 8], + targets={ + "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"}) + }, + ) + + model = DPA3(MODEL_HYPERS, dataset_info).to("cpu") + system = System( + types=torch.tensor([6, 1, 8, 7]), + positions=torch.tensor( + [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]] + ), + cell=torch.zeros(3, 3), + pbc=torch.tensor([False, False, False]), + ) + system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + + model = torch.jit.script(model) + model( + [system], + {"energy": model.outputs["energy"]}, + ) + + +def test_torchscript_save_load(tmpdir): + """Tests that the model can be jitted and saved.""" + + dataset_info = DatasetInfo( + length_unit="Angstrom", + atomic_types=[1, 6, 7, 8], + targets={ + "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"}) + }, + ) + model = DPA3(MODEL_HYPERS, dataset_info) + + with tmpdir.as_cwd(): + torch.jit.save(torch.jit.script(model), "model.pt") + torch.jit.load("model.pt") + + +def test_torchscript_integers(): + """Tests that the model can be jitted when some float + parameters are instead supplied as integers.""" + + new_hypers = copy.deepcopy(MODEL_HYPERS) + new_hypers["cutoff"] = 5 + new_hypers["cutoff_width"] = 1 + + dataset_info = DatasetInfo( + length_unit="Angstrom", + atomic_types=[1, 6, 7, 8], + targets={ + "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"}) + }, + ) + model = DPA3(MODEL_HYPERS, dataset_info).to("cpu") + + system = System( + types=torch.tensor([6, 1, 8, 7]), + positions=torch.tensor( + [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 
2.0], [0.0, 0.0, 3.0]] + ), + cell=torch.zeros(3, 3), + pbc=torch.tensor([False, False, False]), + ) + system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + + model = torch.jit.script(model) + model( + [system], + {"energy": model.outputs["energy"]}, + ) diff --git a/src/metatrain/dpa3/trainer.py b/src/metatrain/experimental/dpa3/trainer.py similarity index 97% rename from src/metatrain/dpa3/trainer.py rename to src/metatrain/experimental/dpa3/trainer.py index f6189b3839..7564df9489 100644 --- a/src/metatrain/dpa3/trainer.py +++ b/src/metatrain/experimental/dpa3/trainer.py @@ -309,7 +309,7 @@ def train( for batch in train_dataloader: optimizer.zero_grad() - + model.to(device) systems, targets, extra_data = batch systems, targets, extra_data = batch_to( systems, targets, extra_data, device=device @@ -504,7 +504,7 @@ def train( def save_checkpoint(self, model, path: Union[str, Path]): checkpoint = { - "architecture_name": "soap_bpnn", + "architecture_name": "dpa3", "model_ckpt_version": model.__checkpoint_version__, "trainer_ckpt_version": self.__checkpoint_version__, "metadata": model.__default_metadata__, @@ -551,6 +551,12 @@ def load_checkpoint( return trainer - @staticmethod - def upgrade_checkpoint(checkpoint: Dict) -> Dict: - raise NotImplementedError("checkpoint upgrade is not implemented for SoapBPNN") + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + if checkpoint["trainer_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using trainer " + f"version {checkpoint['trainer_ckpt_version']}, while the current " + f"trainer version is {cls.__checkpoint_version__}." + ) + return checkpoint \ No newline at end of file From d23aaa173d995e3884f2cedbd8d008d277eab869 Mon Sep 17 00:00:00 2001 From: liwentao Date: Fri, 15 Aug 2025 14:57:20 +0000 Subject: [PATCH 04/18] fix unit tests --- .github/workflows/architecture-tests.yml | 2 +- .../experimental/dpa3/default-hypers.yaml | 2 +- src/metatrain/experimental/dpa3/model.py | 9 +- .../experimental/dpa3/schema-hypers.json | 149 ++++++++++ .../experimental/dpa3/tests/__init__.py | 6 +- .../dpa3/tests/test_checkpoints.py | 2 +- .../experimental/dpa3/tests/test_continue.py | 2 +- .../dpa3/tests/test_functionality.py | 273 ++++++++++++++++++ .../dpa3/tests/test_regression.py | 2 +- .../dpa3/tests/test_torchscript.py | 2 +- src/metatrain/experimental/dpa3/trainer.py | 2 +- tox.ini | 11 + 12 files changed, 450 insertions(+), 12 deletions(-) create mode 100644 src/metatrain/experimental/dpa3/schema-hypers.json create mode 100644 src/metatrain/experimental/dpa3/tests/test_functionality.py diff --git a/.github/workflows/architecture-tests.yml b/.github/workflows/architecture-tests.yml index ffcb87c5af..27085c0bfb 100644 --- a/.github/workflows/architecture-tests.yml +++ b/.github/workflows/architecture-tests.yml @@ -18,7 +18,7 @@ jobs: - pet - nanopet - deprecated-pet - + - dpa3 runs-on: ubuntu-22.04 steps: diff --git a/src/metatrain/experimental/dpa3/default-hypers.yaml b/src/metatrain/experimental/dpa3/default-hypers.yaml index 7d6143dda5..871b5e7585 100644 --- a/src/metatrain/experimental/dpa3/default-hypers.yaml +++ b/src/metatrain/experimental/dpa3/default-hypers.yaml @@ -1,5 +1,5 @@ architecture: - name: dpa3 + name: experimental.dpa3 model: type_map: - "H" diff --git a/src/metatrain/experimental/dpa3/model.py b/src/metatrain/experimental/dpa3/model.py index 4c760cbd6f..2ffd0c56d7 100644 --- 
a/src/metatrain/experimental/dpa3/model.py +++ b/src/metatrain/experimental/dpa3/model.py @@ -349,7 +349,12 @@ def forward( atomic_properties[self.targets_keys] = TensorMap(self.key_labels[self.targets_keys].to(device), blocks) - + if selected_atoms is not None: + for output_name, tmap in atomic_properties.items(): + atomic_properties[output_name] = mts.slice( + tmap, axis="samples", selection=selected_atoms + ) + for output_name, atomic_property in atomic_properties.items(): if outputs[output_name].per_atom: @@ -507,7 +512,7 @@ def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: def get_checkpoint(self) -> Dict: checkpoint = { - "architecture_name": "dpa3", + "architecture_name": "experimental.dpa3", "model_ckpt_version": self.__checkpoint_version__, "metadata": self.metadata, "model_data": { diff --git a/src/metatrain/experimental/dpa3/schema-hypers.json b/src/metatrain/experimental/dpa3/schema-hypers.json new file mode 100644 index 0000000000..7108361de5 --- /dev/null +++ b/src/metatrain/experimental/dpa3/schema-hypers.json @@ -0,0 +1,149 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "name": { + "type": "string", + "enum": ["experimental.dpa3"] + }, + "model": { + "type": "object", + "properties": { + "type_map": { + "type": "array", + "items": { + "type": "string" + } + }, + "descriptor": { + "type": "object", + "properties": { + "repflow": { + "type": "object", + "properties": { + "n_dim": { "type": "integer" }, + "e_dim": { "type": "integer" }, + "a_dim": { "type": "integer" }, + "nlayers": { "type": "integer" }, + "e_rcut": { "type": "number" }, + "e_rcut_smth": { "type": "number" }, + "e_sel": { "type": "integer" }, + "a_rcut": { "type": "number" }, + "a_rcut_smth": { "type": "number" }, + "a_sel": { "type": "integer" }, + "axis_neuron": { "type": "integer" }, + "skip_stat": { "type": "boolean" }, + "a_compress_rate": { "type": "number" }, + "a_compress_e_rate": { "type": "number" }, + "a_compress_use_split": { "type": "boolean" }, + "update_angle": { "type": "boolean" }, + "update_style": { "type": "string" }, + "update_residual": { "type": "number" }, + "update_residual_init": { "type": "string" }, + "smooth_edge_update": { "type": "boolean" }, + "use_dynamic_sel": { "type": "boolean" }, + "sel_reduce_factor": { "type": "number" } + }, + "required": [ + "n_dim", "e_dim", "a_dim", "nlayers" + ] + }, + "activation_function": { "type": "string" }, + "use_tebd_bias": { "type": "boolean" }, + "precision": { "type": "string" }, + "concat_output_tebd": { "type": "boolean" } + } + }, + "fitting_net": { + "type": "object", + "properties": { + "neuron": { + "type": "array", + "items": { "type": "integer" } + }, + "resnet_dt": { "type": "boolean" }, + "seed": { "type": "integer" }, + "precision": { "type": "string" }, + "activation_function": { "type": "string" }, + "type": { "type": "string" }, + "numb_fparam": { "type": "integer" }, + "numb_aparam": { "type": "integer" }, + "dim_case_embd": { "type": "integer" }, + "trainable": { "type": "boolean" }, + "rcond": { "type": ["number", "null"] }, + "atom_ener": { + "type": "array", + "items": { "type": "number" } + }, + "use_aparam_as_mask": { "type": "boolean" } + } + }, + "cutoff_width": { + "type": "object", + "properties": { + "d_pet": { "type": "integer" }, + "d_head": { "type": "integer" }, + "d_feedforward": { "type": "integer" }, + "num_heads": { "type": "integer" }, + "num_attention_layers": { "type": "integer" }, + "num_gnn_layers": { "type": "integer" }, + "zbl": { "type": 
"boolean" }, + "long_range": { + "type": "object", + "properties": { + "enable": { "type": "boolean" }, + "use_ewald": { "type": "boolean" }, + "smearing": { "type": "number" }, + "kspace_resolution": { "type": "number" }, + "interpolation_nodes": { "type": "integer" } + } + } + } + } + }, + "additionalProperties": false + }, + "training": { + "type": "object", + "properties": { + "distributed": { "type": "boolean" }, + "distributed_port": { "type": "integer" }, + "batch_size": { "type": "integer" }, + "num_epochs": { "type": "integer" }, + "learning_rate": { "type": "number" }, + "early_stopping_patience": { "type": "integer" }, + "scheduler_patience": { "type": "integer" }, + "scheduler_factor": { "type": "number" }, + "log_interval": { "type": "integer" }, + "checkpoint_interval": { "type": "integer" }, + "scale_targets": { "type": "boolean" }, + "fixed_composition_weights": { + "type": "object", + "patternProperties": { + "^.*$": { + "type": "object", + "propertyNames": { + "pattern": "^[0-9]+$" + }, + "additionalProperties": { "type": "number" } + } + }, + "additionalProperties": false + }, + "per_structure_targets": { + "type": "array", + "items": { "type": "string" } + }, + "log_mae": { "type": "boolean" }, + "log_separate_blocks": { "type": "boolean" }, + "best_model_metric": { + "type": "string", + "enum": ["rmse_prod", "mae_prod", "loss"] + }, + "loss": { "type": "object" }, + "additionalProperties": false + } + } + }, + "additionalProperties": false +} \ No newline at end of file diff --git a/src/metatrain/experimental/dpa3/tests/__init__.py b/src/metatrain/experimental/dpa3/tests/__init__.py index b997a00c2e..f9f29d3a06 100644 --- a/src/metatrain/experimental/dpa3/tests/__init__.py +++ b/src/metatrain/experimental/dpa3/tests/__init__.py @@ -3,10 +3,10 @@ from metatrain.utils.architectures import get_default_hypers -DATASET_PATH = str(Path(__file__).parents[4] / "tests/resources/qm9_reduced_100.xyz") +DATASET_PATH = str(Path(__file__).parents[5] / "tests/resources/qm9_reduced_100.xyz") DATASET_WITH_FORCES_PATH = str( - Path(__file__).parents[4] / "tests/resources/carbon_reduced_100.xyz" + Path(__file__).parents[5] / "tests/resources/carbon_reduced_100.xyz" ) -DEFAULT_HYPERS = get_default_hypers("dpa3") +DEFAULT_HYPERS = get_default_hypers("experimental.dpa3") MODEL_HYPERS = DEFAULT_HYPERS["model"] diff --git a/src/metatrain/experimental/dpa3/tests/test_checkpoints.py b/src/metatrain/experimental/dpa3/tests/test_checkpoints.py index de97c89e58..5c60934413 100644 --- a/src/metatrain/experimental/dpa3/tests/test_checkpoints.py +++ b/src/metatrain/experimental/dpa3/tests/test_checkpoints.py @@ -3,7 +3,7 @@ import pytest import torch -from metatrain.dpa3 import DPA3, Trainer +from metatrain.experimental.dpa3 import DPA3, Trainer from metatrain.utils.data import ( DatasetInfo, get_atomic_types, diff --git a/src/metatrain/experimental/dpa3/tests/test_continue.py b/src/metatrain/experimental/dpa3/tests/test_continue.py index 8bdb951da6..a34c2050b6 100644 --- a/src/metatrain/experimental/dpa3/tests/test_continue.py +++ b/src/metatrain/experimental/dpa3/tests/test_continue.py @@ -4,7 +4,7 @@ import torch from omegaconf import OmegaConf -from metatrain.dpa3 import DPA3, Trainer +from metatrain.experimental.dpa3 import DPA3, Trainer from metatrain.utils.data import Dataset, DatasetInfo from metatrain.utils.data.readers import read_systems, read_targets from metatrain.utils.data.target_info import get_energy_target_info diff --git a/src/metatrain/experimental/dpa3/tests/test_functionality.py 
b/src/metatrain/experimental/dpa3/tests/test_functionality.py new file mode 100644 index 0000000000..142f075035 --- /dev/null +++ b/src/metatrain/experimental/dpa3/tests/test_functionality.py @@ -0,0 +1,273 @@ +import metatensor.torch as mts +import pytest +import torch +from jsonschema.exceptions import ValidationError +from metatomic.torch import ModelOutput, System +from omegaconf import OmegaConf + +from metatrain.experimental.dpa3 import DPA3 +from metatrain.utils.architectures import check_architecture_options +from metatrain.utils.data import DatasetInfo +from metatrain.utils.data.target_info import ( + get_energy_target_info, + get_generic_target_info, +) +from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists + +from . import DEFAULT_HYPERS, MODEL_HYPERS + + +def test_prediction(): + """Tests the basic functionality of the forward pass of the model.""" + dataset_info = DatasetInfo( + length_unit="Angstrom", + atomic_types=[1, 6, 7, 8], + targets={ + "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"}) + }, + ) + + model = DPA3(MODEL_HYPERS, dataset_info).to("cpu") + + system = System( + types=torch.tensor([6, 6]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), + cell=torch.zeros(3, 3), + pbc=torch.tensor([False, False, False]), + ) + system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + outputs = {"energy": ModelOutput(per_atom=False)} + model([system, system], outputs) + + +def test_dpa3_padding(): + """Tests that the model predicts the same energy independently of the + padding size.""" + + dataset_info = DatasetInfo( + length_unit="Angstrom", + atomic_types=[1, 6, 7, 8], + targets={ + "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"}) + }, + ) + + model = DPA3(MODEL_HYPERS, dataset_info).to("cpu") + + system = System( + types=torch.tensor([6, 6]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), + cell=torch.zeros(3, 3), + pbc=torch.tensor([False, False, False]), + ) + system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + outputs = {"energy": ModelOutput(per_atom=False)} + lone_output = model([system], outputs) + + system_2 = System( + types=torch.tensor([6, 6, 6, 6, 6, 6, 6]), + positions=torch.tensor( + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, 0.0, 2.0], + [0.0, 0.0, 3.0], + [0.0, 0.0, 4.0], + [0.0, 0.0, 5.0], + [0.0, 0.0, 6.0], + ] + ), + cell=torch.zeros(3, 3), + pbc=torch.tensor([False, False, False]), + ) + system_2 = get_system_with_neighbor_lists( + system_2, model.requested_neighbor_lists() + ) + padded_output = model([system, system_2], outputs) + + lone_energy = lone_output["energy"].block().values.squeeze(-1)[0] + padded_energy = padded_output["energy"].block().values.squeeze(-1)[0] + + assert torch.allclose(lone_energy, padded_energy, atol=1e-6, rtol=1e-6) + + +def test_prediction_subset_elements(): + """Tests that the model can predict on a subset of the elements it was trained + on.""" + + dataset_info = DatasetInfo( + length_unit="Angstrom", + atomic_types=[1, 6, 7, 8], + targets={ + "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"}) + }, + ) + + model = DPA3(MODEL_HYPERS, dataset_info).to("cpu") + + system = System( + types=torch.tensor([6, 6]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), + cell=torch.zeros(3, 3), + pbc=torch.tensor([False, False, False]), + ) + system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + model( + [system], + 
{"energy": model.outputs["energy"]}, + ) + + +def test_prediction_subset_atoms(): + """Tests that the model can predict on a subset + of the atoms in a system.""" + + # we need float64 for this test, then we will change it back at the end + default_dtype_before = torch.get_default_dtype() + torch.set_default_dtype(torch.float64) + + dataset_info = DatasetInfo( + length_unit="Angstrom", + atomic_types=[1, 6, 7, 8], + targets={ + "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"}) + }, + ) + + model = DPA3(MODEL_HYPERS, dataset_info).to("cpu") + + # Since we don't yet support atomic predictions, we will test this by + # predicting on a system with two monomers at a large distance + + system_monomer = System( + types=torch.tensor([7, 8, 8]), + positions=torch.tensor( + [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0]], + ), + cell=torch.zeros(3, 3), + pbc=torch.tensor([False, False, False]), + ) + system_monomer = get_system_with_neighbor_lists( + system_monomer, model.requested_neighbor_lists() + ) + + energy_monomer = model( + [system_monomer], + {"energy": ModelOutput(per_atom=False)}, + ) + + system_far_away_dimer = System( + types=torch.tensor([7, 7, 8, 8, 8, 8]), + positions=torch.tensor( + [ + [0.0, 0.0, 0.0], + [0.0, 50.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, 0.0, 2.0], + [0.0, 51.0, 0.0], + [0.0, 42.0, 0.0], + ], + ), + cell=torch.zeros(3, 3), + pbc=torch.tensor([False, False, False]), + ) + system_far_away_dimer = get_system_with_neighbor_lists( + system_far_away_dimer, model.requested_neighbor_lists() + ) + + selection_labels = mts.Labels( + names=["system", "atom"], + values=torch.tensor([[0, 0], [0, 2], [0, 3]]), + ) + + energy_dimer = model( + [system_far_away_dimer], + {"energy": ModelOutput(per_atom=False)}, + ) + + energy_monomer_in_dimer = model( + [system_far_away_dimer], + {"energy": ModelOutput(per_atom=False)}, + selected_atoms=selection_labels, + ) + + assert not mts.allclose(energy_monomer["energy"], energy_dimer["energy"]) + + assert mts.allclose(energy_monomer["energy"], energy_monomer_in_dimer["energy"]) + + torch.set_default_dtype(default_dtype_before) + + +def test_output_per_atom(): + """Tests that the model can output per-atom quantities.""" + dataset_info = DatasetInfo( + length_unit="Angstrom", + atomic_types=[1, 6, 7, 8], + targets={ + "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"}) + }, + ) + + model = DPA3(MODEL_HYPERS, dataset_info).to("cpu") + + system = System( + types=torch.tensor([6, 1, 8, 7]), + positions=torch.tensor( + [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]], + ), + cell=torch.zeros(3, 3), + pbc=torch.tensor([False, False, False]), + ) + system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + + outputs = model( + [system], + {"energy": model.outputs["energy"]}, + ) + + assert outputs["energy"].block().samples.names == ["system", "atom"] + assert outputs["energy"].block().values.shape == (4, 1) + + +def test_fixed_composition_weights(): + """Tests the correctness of the json schema for fixed_composition_weights""" + + hypers = DEFAULT_HYPERS.copy() + hypers["training"]["fixed_composition_weights"] = { + "energy": { + 1: 1.0, + 6: 0.0, + 7: 0.0, + 8: 0.0, + 9: 3000.0, + } + } + hypers = OmegaConf.create(hypers) + check_architecture_options(name="experimental.dpa3", options=OmegaConf.to_container(hypers)) + + + + +def test_pet_single_atom(): + """Tests that the model predicts correctly on a single atom.""" + + dataset_info = DatasetInfo( + 
length_unit="Angstrom", + atomic_types=[1, 6, 7, 8], + targets={ + "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"}) + }, + ) + model = DPA3(MODEL_HYPERS, dataset_info).to("cpu") + + system = System( + types=torch.tensor([6]), + positions=torch.tensor([[0.0, 0.0, 1.0]]), + cell=torch.zeros(3, 3), + pbc=torch.tensor([False, False, False]), + ) + system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + outputs = {"energy": ModelOutput(per_atom=False)} + model([system], outputs) + + diff --git a/src/metatrain/experimental/dpa3/tests/test_regression.py b/src/metatrain/experimental/dpa3/tests/test_regression.py index a76d34bdf8..4c45795d82 100644 --- a/src/metatrain/experimental/dpa3/tests/test_regression.py +++ b/src/metatrain/experimental/dpa3/tests/test_regression.py @@ -5,7 +5,7 @@ from metatomic.torch import ModelOutput from omegaconf import OmegaConf -from metatrain.dpa3 import DPA3, Trainer +from metatrain.experimental.dpa3 import DPA3, Trainer from metatrain.utils.data import Dataset, DatasetInfo from metatrain.utils.data.readers import ( read_systems, diff --git a/src/metatrain/experimental/dpa3/tests/test_torchscript.py b/src/metatrain/experimental/dpa3/tests/test_torchscript.py index 725debc45a..c0b51cacc5 100644 --- a/src/metatrain/experimental/dpa3/tests/test_torchscript.py +++ b/src/metatrain/experimental/dpa3/tests/test_torchscript.py @@ -3,7 +3,7 @@ import torch from metatomic.torch import System -from metatrain.dpa3 import DPA3 +from metatrain.experimental.dpa3 import DPA3 from metatrain.utils.data import DatasetInfo from metatrain.utils.data.target_info import get_energy_target_info from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists diff --git a/src/metatrain/experimental/dpa3/trainer.py b/src/metatrain/experimental/dpa3/trainer.py index 7564df9489..1e7acd5a56 100644 --- a/src/metatrain/experimental/dpa3/trainer.py +++ b/src/metatrain/experimental/dpa3/trainer.py @@ -504,7 +504,7 @@ def train( def save_checkpoint(self, model, path: Union[str, Path]): checkpoint = { - "architecture_name": "dpa3", + "architecture_name": "experimental.dpa3", "model_ckpt_version": model.__checkpoint_version__, "trainer_ckpt_version": self.__checkpoint_version__, "metadata": model.__default_metadata__, diff --git a/tox.ini b/tox.ini index e5ab2e5d3b..8742cba893 100644 --- a/tox.ini +++ b/tox.ini @@ -11,6 +11,7 @@ envlist = pet-tests nanopet-tests deprecated-pet-tests + dpa3-tests [testenv] package = editable @@ -135,6 +136,16 @@ changedir = src/metatrain/pet/tests/ commands = pytest {posargs} +[testenv:dpa3-tests] +description = Run DPA3 tests with pytest +passenv = * +deps = + pytest +extras = dpa3 +changedir = src/metatrain/experimental/dpa3/tests/ +commands = + pytest {posargs} + [testenv:gap-tests] description = Run GAP tests with pytest passenv = * From 243fc48195a98c1cdac4833a1e1ec89d36adf304 Mon Sep 17 00:00:00 2001 From: liwentao Date: Wed, 24 Sep 2025 04:20:26 +0000 Subject: [PATCH 05/18] DPA3 review revision 0924 --- docs/static/qm9/options.yaml | 4 +- examples/basic_usage/run_dpa3.sh | 2 +- src/metatrain/experimental/dpa3/__init__.py | 11 ++- .../experimental/dpa3/default-hypers.yaml | 4 +- src/metatrain/experimental/dpa3/model.py | 73 +++++-------------- src/metatrain/experimental/dpa3/trainer.py | 9 ++- 6 files changed, 39 insertions(+), 64 deletions(-) diff --git a/docs/static/qm9/options.yaml b/docs/static/qm9/options.yaml index b32a40d94f..c0c0ddf7bf 100644 --- a/docs/static/qm9/options.yaml +++ 
b/docs/static/qm9/options.yaml @@ -1,8 +1,8 @@ # architecture used to train the model architecture: - name: experimental.dpa3 + name: soap_bpnn training: - num_epochs: 2 # a very short training run + num_epochs: 5 # a very short training run batch_size: 10 # Mandatory section defining the parameters for system and target data of the diff --git a/examples/basic_usage/run_dpa3.sh b/examples/basic_usage/run_dpa3.sh index 98e3a8ea8b..325168d77f 100644 --- a/examples/basic_usage/run_dpa3.sh +++ b/examples/basic_usage/run_dpa3.sh @@ -1,6 +1,6 @@ export METATENSOR_DEBUG_EXTENSIONS_LOADING=1 -# mtt train options.yaml +mtt train options.yaml package_dir=$(python -c "import site; print(site.getsitepackages()[0])") cp $package_dir/deepmd/lib/*.so extensions/deepmd/lib/ diff --git a/src/metatrain/experimental/dpa3/__init__.py b/src/metatrain/experimental/dpa3/__init__.py index fdd4575b92..66be2be6cb 100644 --- a/src/metatrain/experimental/dpa3/__init__.py +++ b/src/metatrain/experimental/dpa3/__init__.py @@ -3,4 +3,13 @@ __model__ = DPA3 -__trainer__ = Trainer \ No newline at end of file +__trainer__ = Trainer + +__authors__ = [ + ("Duo Zhang ", "@duozhang"), +] + +__maintainers__ = [ + ("Duo Zhang ", "@duozhang"), + ("Wentao Li ", "@wentaoli"), +] diff --git a/src/metatrain/experimental/dpa3/default-hypers.yaml b/src/metatrain/experimental/dpa3/default-hypers.yaml index 871b5e7585..91b51e2607 100644 --- a/src/metatrain/experimental/dpa3/default-hypers.yaml +++ b/src/metatrain/experimental/dpa3/default-hypers.yaml @@ -33,7 +33,7 @@ architecture: sel_reduce_factor: 10.0 activation_function: "custom_silu:10.0" use_tebd_bias: false - precision: "float64" + precision: "float32" concat_output_tebd: false fitting_net: neuron: @@ -42,7 +42,7 @@ architecture: - 240 resnet_dt: true seed: 1 - precision: "float64" + precision: "float32" activation_function: "custom_silu:10.0" type: "ener" numb_fparam: 0 diff --git a/src/metatrain/experimental/dpa3/model.py b/src/metatrain/experimental/dpa3/model.py index 2ffd0c56d7..f8112f99b8 100644 --- a/src/metatrain/experimental/dpa3/model.py +++ b/src/metatrain/experimental/dpa3/model.py @@ -15,7 +15,7 @@ ) from metatrain.utils.abc import ModelInterface -from metatrain.utils.additive import ZBL, OldCompositionModel +from metatrain.utils.additive import ZBL,CompositionModel from metatrain.utils.data import TargetInfo from metatrain.utils.data.dataset import DatasetInfo from metatrain.utils.dtype import dtype_to_str @@ -87,7 +87,7 @@ def concatenate_structures( class DPA3(ModelInterface): __checkpoint_version__ = 1 __supported_devices__ = ["cuda", "cpu"] - __supported_dtypes__ = [torch.float64, torch.float32] + __supported_dtypes__ = [torch.float32,torch.float64] __default_metadata__ = ModelMetadata( references={ "implementation": [ @@ -104,6 +104,14 @@ class DPA3(ModelInterface): def __init__(self, hypers: Dict, dataset_info: DatasetInfo) -> None: super().__init__(hypers, dataset_info, self.__default_metadata__) self.atomic_types = dataset_info.atomic_types + self.dtype = self.hypers["descriptor"]["precision"] + + if self.dtype == "float64": + self.dtype = torch.float64 + elif self.dtype == "float32": + self.dtype = torch.float32 + else: + raise ValueError(f"Unsupported precision: {self.dtype}") self.requested_nl = NeighborListOptions( cutoff=self.hypers["descriptor"]["repflow"]["e_rcut"], @@ -113,7 +121,7 @@ def __init__(self, hypers: Dict, dataset_info: DatasetInfo) -> None: self.targets_keys = list(dataset_info.targets.keys())[0] self.model = get_standard_model(hypers) - 
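With this revision the default `precision` drops from float64 to float32, and the model maps that string onto a torch dtype at construction time. A user who still wants double precision could override it from `options.yaml`; a sketch, assuming the nested keys of the default hypers above (not a tested configuration):

  architecture:
    name: experimental.dpa3
    model:
      descriptor:
        precision: float64
      fitting_net:
        precision: float64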
+ self.scaler = Scaler(hypers={}, dataset_info=dataset_info) self.outputs = { "features": ModelOutput(unit="", per_atom=True) @@ -128,7 +136,7 @@ def __init__(self, hypers: Dict, dataset_info: DatasetInfo) -> None: for target_name, target in dataset_info.targets.items(): self._add_output(target_name, target) - composition_model = OldCompositionModel( + composition_model = CompositionModel( hypers={}, dataset_info=DatasetInfo( length_unit=dataset_info.length_unit, @@ -136,60 +144,13 @@ def __init__(self, hypers: Dict, dataset_info: DatasetInfo) -> None: targets={ target_name: target_info for target_name, target_info in dataset_info.targets.items() - if OldCompositionModel.is_valid_target(target_name, target_info) + if CompositionModel.is_valid_target(target_name, target_info) }, ), ) additive_models = [composition_model] self.additive_models = torch.nn.ModuleList(additive_models) - self.reverse_precision_dict ={ - torch.float16: "float16", - torch.float32: "float32", - torch.float64: "float64", - torch.int32: "int32", - torch.int64: "int64", - torch.bfloat16: "bfloat16", - torch.bool: "bool", - } - - def _input_type_cast( - self, - coord: torch.Tensor, - box: Optional[torch.Tensor] = None, - fparam: Optional[torch.Tensor] = None, - aparam: Optional[torch.Tensor] = None, - ) -> tuple[ - torch.Tensor, - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], - str, - ]: - """Cast the input data to global float type.""" - input_prec = self.reverse_precision_dict[coord.dtype] - - _lst: list[Optional[torch.Tensor]] = [ - vv.to(coord.dtype) if vv is not None else None - for vv in [box, fparam, aparam] - ] - box, fparam, aparam = _lst - if ( - input_prec - == self.reverse_precision_dict[self.global_pt_float_precision] - ): - return coord, box, fparam, aparam, input_prec - else: - pp = torch.float32 - return ( - coord.to(pp), - box.to(pp) if box is not None else None, - fparam.to(pp) if fparam is not None else None, - aparam.to(pp) if aparam is not None else None, - input_prec, - ) - - def _add_output(self, target_name: str, target: TargetInfo) -> None: self.num_properties[target_name] = {} @@ -226,10 +187,10 @@ def forward( outputs: Dict[str, ModelOutput], selected_atoms: Optional[Labels] = None, ) -> Dict[str, TensorMap]: - device = systems[0].positions.device - position_dtype = systems[0].positions.dtype + atype_dtype = systems[0].types.dtype + system_indices = torch.concatenate( [ torch.full( @@ -285,7 +246,7 @@ def forward( system_index ) = concatenate_structures(systems) - positions = positions.to(position_dtype) + type_to_index = {atomic_type: idx for idx, atomic_type in enumerate(self.atomic_types)} type_to_index[-1] = -1 @@ -418,7 +379,7 @@ def restart(self, dataset_info: DatasetInfo) -> "DPA3": targets={ target_name: target_info for target_name, target_info in dataset_info.targets.items() - if OldCompositionModel.is_valid_target(target_name, target_info) + if CompositionModel.is_valid_target(target_name, target_info) }, ), ) diff --git a/src/metatrain/experimental/dpa3/trainer.py b/src/metatrain/experimental/dpa3/trainer.py index 1e7acd5a56..51afc9ff0b 100644 --- a/src/metatrain/experimental/dpa3/trainer.py +++ b/src/metatrain/experimental/dpa3/trainer.py @@ -61,7 +61,7 @@ def train( checkpoint_dir: str, ): assert dtype in DPA3.__supported_dtypes__ - + is_distributed = self.hypers["distributed"] if is_distributed: @@ -126,9 +126,12 @@ def train( additive_model.to(dtype=torch.float64) logging.info("Calculating composition weights") + 
model.additive_models[0].train_model( # this is the composition model train_datasets, model.additive_models[1:], + self.hypers["batch_size"], + is_distributed, self.hypers["fixed_composition_weights"], ) @@ -311,6 +314,7 @@ def train( optimizer.zero_grad() model.to(device) systems, targets, extra_data = batch + systems, targets, extra_data = batch_to( systems, targets, extra_data, device=device ) @@ -338,7 +342,8 @@ def train( predictions, systems, per_structure_targets ) targets = average_by_num_atoms(targets, systems, per_structure_targets) - + + train_loss_batch = loss_fn(predictions, targets) train_loss_batch.backward() From 96b6e11b0b79b8d101a7c11e0c7a86306174b8b9 Mon Sep 17 00:00:00 2001 From: liwentao Date: Wed, 24 Sep 2025 04:24:34 +0000 Subject: [PATCH 06/18] add code owner for DPA3 --- CODEOWNERS | 1 + 1 file changed, 1 insertion(+) diff --git a/CODEOWNERS b/CODEOWNERS index f2eccb2635..d0981292d4 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -6,3 +6,4 @@ **/pet @abmazitov **/gap @DavideTisi **/nanopet @frostedoyster +**/dpa3 @wentaoli \ No newline at end of file From 69895b3869519423b93a7a8364f0e485ebf4f398 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 26 Sep 2025 19:29:00 +0200 Subject: [PATCH 07/18] Lint and add documentation --- README.md | 3 +- .../advanced-concepts/auxiliary-outputs.rst | 14 +- .../fitting-generic-targets.rst | 5 + docs/src/architectures/dpa3.rst | 8 + docs/src/architectures/nanopet.rst | 2 +- pyproject.toml | 3 +- .../experimental/dpa3/default-hypers.yaml | 29 +- src/metatrain/experimental/dpa3/model.py | 189 ++++-------- .../experimental/dpa3/schema-hypers.json | 290 +++++++++++++----- .../dpa3/tests/test_checkpoints.py | 5 +- .../experimental/dpa3/tests/test_continue.py | 5 +- .../dpa3/tests/test_functionality.py | 13 +- .../dpa3/tests/test_regression.py | 20 +- .../dpa3/tests/test_torchscript.py | 4 +- src/metatrain/experimental/dpa3/trainer.py | 21 +- tests/resources/options.yaml | 9 +- tests/utils/test_architectures.py | 1 + 17 files changed, 350 insertions(+), 271 deletions(-) create mode 100644 docs/src/architectures/dpa3.rst diff --git a/README.md b/README.md index 5df816df25..2ca3c22a6c 100644 --- a/README.md +++ b/README.md @@ -37,10 +37,11 @@ model: | Name | Description | |--------------------------|--------------------------------------------------------------------------------------------------------------------------------------| | GAP | Sparse Gaussian Approximation Potential (GAP) using Smooth Overlap of Atomic Positions (SOAP). | -| PET | Point Edge Transformer (PET), interatomic machine learning potential | +| PET | Point Edge Transformer | | NanoPET *(experimental)* | Re-implementation of the original PET with slightly improved training and evaluation speed | | PET *(deprecated)* | Original implementation of the PET model used for prototyping, now deprecated in favor of the native `metatrain` PET implementation. 
|
 | SOAP BPNN                | A Behler-Parrinello neural network with SOAP features                                                                                  |
+| DPA3                     | An invariant graph neural network based on line graph series representations                                                          |
 
 
 
diff --git a/docs/src/advanced-concepts/auxiliary-outputs.rst b/docs/src/advanced-concepts/auxiliary-outputs.rst
index a9ac13d268..cd3755b157 100644
--- a/docs/src/advanced-concepts/auxiliary-outputs.rst
+++ b/docs/src/advanced-concepts/auxiliary-outputs.rst
@@ -31,13 +31,13 @@ by one or more architectures in the library:
 The following table shows the architectures that support each of the
 auxiliary outputs:
 
-+--------------------------------------------+-----------+------+-----+---------+
-| Auxiliary output                           | SOAP-BPNN | PET  | GAP | NanoPET |
-+--------------------------------------------+-----------+------+-----+---------+
-| ``mtt::aux::{target}_last_layer_features`` | Yes       | Yes  | No  | Yes     |
-+--------------------------------------------+-----------+------+-----+---------+
-| ``features``                               | Yes       | Yes  | No  | Yes     |
-+--------------------------------------------+-----------+------+-----+---------+
++--------------------------------------------+-----------+------+-----+---------+------+
+| Auxiliary output                           | SOAP-BPNN | PET  | GAP | NanoPET | DPA3 |
++--------------------------------------------+-----------+------+-----+---------+------+
+| ``mtt::aux::{target}_last_layer_features`` | Yes       | Yes  | No  | Yes     | No   |
++--------------------------------------------+-----------+------+-----+---------+------+
+| ``features``                               | Yes       | Yes  | No  | Yes     | No   |
++--------------------------------------------+-----------+------+-----+---------+------+
 
 The following tables show the metadata that will be provided for each of the
 auxiliary outputs:
diff --git a/docs/src/advanced-concepts/fitting-generic-targets.rst b/docs/src/advanced-concepts/fitting-generic-targets.rst
index bb0701425e..73349fdb36 100644
--- a/docs/src/advanced-concepts/fitting-generic-targets.rst
+++ b/docs/src/advanced-concepts/fitting-generic-targets.rst
@@ -40,6 +40,11 @@ capabilities of the architectures in metatrain.
     - Yes
     - Yes
     - Only with ``rank=1`` (vectors) and ``rank=2`` (2D tensors)
+  * - DPA3
+    - Energy, forces, virial
+    - Yes
+    - No
+    - No
 
 
 Preparing generic targets for reading by metatrain
diff --git a/docs/src/architectures/dpa3.rst b/docs/src/architectures/dpa3.rst
new file mode 100644
index 0000000000..6591f14a52
--- /dev/null
+++ b/docs/src/architectures/dpa3.rst
@@ -0,0 +1,8 @@
+.. _architecture-dpa3:
+
+DPA3 (experimental)
+======================
+
+.. warning::
+
+  This is an **experimental architecture**. You should not use it for anything important.
diff --git a/docs/src/architectures/nanopet.rst b/docs/src/architectures/nanopet.rst
index f4ec15949f..19065ea22f 100644
--- a/docs/src/architectures/nanopet.rst
+++ b/docs/src/architectures/nanopet.rst
@@ -5,7 +5,7 @@ NanoPET (experimental)
 
 .. warning::
 
-  This is an **experimental model**. You should not use it for anything important.
+  This is an **experimental architecture**. You should not use it for anything important.
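The new `dpa3.rst` page is still a stub. Until it is expanded, the smallest working configuration is simply selecting the architecture from `options.yaml`; a sketch, with the training values borrowed from the qm9 example used elsewhere in this series:

  architecture:
    name: experimental.dpa3
    training:
      num_epochs: 2
      batch_size: 10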
This is a more user-friendly re-implementation of the original PET :footcite:p:`pozdnyakov_smooth_2023` (which lives in https://github.com/spozdn/pet), diff --git a/pyproject.toml b/pyproject.toml index d23c92cf51..b50807c8dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,8 @@ build-backend = "setuptools.build_meta" [project.optional-dependencies] dpa3 = [ - "deepmd-kit>=3.1.0" + "deepmd-kit>=3.1.0", + "torch>=2.7", ] soap-bpnn = [ "torch-spex>=0.1,<0.2", diff --git a/src/metatrain/experimental/dpa3/default-hypers.yaml b/src/metatrain/experimental/dpa3/default-hypers.yaml index 91b51e2607..8100eb8fde 100644 --- a/src/metatrain/experimental/dpa3/default-hypers.yaml +++ b/src/metatrain/experimental/dpa3/default-hypers.yaml @@ -1,13 +1,9 @@ architecture: name: experimental.dpa3 model: - type_map: - - "H" - - "C" - - "N" - - "O" + type_map: [H, C, N, O] descriptor: - type: "dpa3" + type: dpa3 repflow: n_dim: 128 e_dim: 64 @@ -25,26 +21,23 @@ architecture: a_compress_e_rate: 2 a_compress_use_split: true update_angle: true - update_style: "res_residual" + update_style: res_residual update_residual: 0.1 - update_residual_init: "const" + update_residual_init: const smooth_edge_update: true use_dynamic_sel: true sel_reduce_factor: 10.0 - activation_function: "custom_silu:10.0" + activation_function: custom_silu:10.0 use_tebd_bias: false - precision: "float32" + precision: float32 concat_output_tebd: false fitting_net: - neuron: - - 240 - - 240 - - 240 + neuron: [240, 240, 240] resnet_dt: true seed: 1 - precision: "float32" - activation_function: "custom_silu:10.0" - type: "ener" + precision: float32 + activation_function: custom_silu:10.0 + type: ener numb_fparam: 0 numb_aparam: 0 dim_case_embd: 0 @@ -72,4 +65,4 @@ architecture: loss: type: mse weights: {} - reduction: mean \ No newline at end of file + reduction: mean diff --git a/src/metatrain/experimental/dpa3/model.py b/src/metatrain/experimental/dpa3/model.py index f8112f99b8..54d7fcbfc0 100644 --- a/src/metatrain/experimental/dpa3/model.py +++ b/src/metatrain/experimental/dpa3/model.py @@ -1,10 +1,9 @@ from typing import Any, Dict, List, Literal, Optional -import copy + import metatensor.torch as mts import torch +from deepmd.pt.model.model import get_standard_model from metatensor.torch import Labels, TensorBlock, TensorMap -from metatensor.torch.learn.nn import Linear as LinearMap -from metatensor.torch.learn.nn import ModuleMap from metatomic.torch import ( AtomisticModel, ModelCapabilities, @@ -15,35 +14,17 @@ ) from metatrain.utils.abc import ModelInterface -from metatrain.utils.additive import ZBL,CompositionModel +from metatrain.utils.additive import CompositionModel from metatrain.utils.data import TargetInfo from metatrain.utils.data.dataset import DatasetInfo from metatrain.utils.dtype import dtype_to_str -from metatrain.utils.long_range import DummyLongRangeFeaturizer, LongRangeFeaturizer from metatrain.utils.metadata import merge_metadata from metatrain.utils.scaler import Scaler from metatrain.utils.sum_over_atoms import sum_over_atoms -from deepmd.pt.model.model import get_standard_model - -def update_v1_v2(state_dict): - # This if-statement is necessary to handle cases when - # best_model_state_dict and model_state_dict are the same. - # In that case, the both are updated within the first call of - # this function in the PET.update_checkpoint() method. 
- if ( - state_dict is not None - and "additive_models.0.model.type_to_index" not in state_dict - ): - state_dict["additive_models.0.model.type_to_index"] = state_dict.pop( - "additive_models.0.type_to_index" - ) - # Data processing -def concatenate_structures( - systems: List[System] -): +def concatenate_structures(systems: List[System]): device = systems[0].positions.device positions = [] species = [] @@ -58,28 +39,30 @@ def concatenate_structures( atom_nums.append(len(system.positions)) atom_index_list.append(torch.arange(start=0, end=len(system.positions))) system_index_list.append(torch.full((len(system.positions),), i)) - max_atom_num = max(atom_nums) + max_atom_num = max(atom_nums) atom_index = torch.cat(atom_index_list, dim=0).to(torch.int32).to(device) system_index = torch.cat(system_index_list, dim=0).to(torch.int32).to(device) - - positions = torch.zeros((len(systems), max_atom_num, 3), dtype=systems[0].positions.dtype) + + positions = torch.zeros( + (len(systems), max_atom_num, 3), dtype=systems[0].positions.dtype + ) species = torch.full((len(systems), max_atom_num), -1, dtype=systems[0].types.dtype) - cells = torch.stack([system.cell for system in systems]) # 形状为 [batch_size, 3, 3] 或相应的晶胞形状 - + cells = torch.stack( + [system.cell for system in systems] + ) # 形状为 [batch_size, 3, 3] 或相应的晶胞形状 for i, system in enumerate(systems): - positions[i, :len(system.positions)] = system.positions - species[i, :len(system.positions)] = system.types + positions[i, : len(system.positions)] = system.positions + species[i, : len(system.positions)] = system.types cells[i] = system.cell node_counter += len(system.positions) - return ( positions.to(device), species.to(device), cells.to(device), atom_index, - system_index + system_index, ) @@ -87,7 +70,7 @@ def concatenate_structures( class DPA3(ModelInterface): __checkpoint_version__ = 1 __supported_devices__ = ["cuda", "cpu"] - __supported_dtypes__ = [torch.float32,torch.float64] + __supported_dtypes__ = [torch.float32, torch.float64] __default_metadata__ = ModelMetadata( references={ "implementation": [ @@ -105,14 +88,14 @@ def __init__(self, hypers: Dict, dataset_info: DatasetInfo) -> None: super().__init__(hypers, dataset_info, self.__default_metadata__) self.atomic_types = dataset_info.atomic_types self.dtype = self.hypers["descriptor"]["precision"] - + if self.dtype == "float64": self.dtype = torch.float64 elif self.dtype == "float32": self.dtype = torch.float32 else: raise ValueError(f"Unsupported precision: {self.dtype}") - + self.requested_nl = NeighborListOptions( cutoff=self.hypers["descriptor"]["repflow"]["e_rcut"], full_list=True, @@ -121,15 +104,13 @@ def __init__(self, hypers: Dict, dataset_info: DatasetInfo) -> None: self.targets_keys = list(dataset_info.targets.keys())[0] self.model = get_standard_model(hypers) - + self.scaler = Scaler(hypers={}, dataset_info=dataset_info) - self.outputs = { - "features": ModelOutput(unit="", per_atom=True) - } + self.outputs: Dict[str, ModelOutput] = {} self.single_label = Labels.single() self.num_properties: Dict[str, Dict[str, int]] = {} - + self.key_labels: Dict[str, Labels] = {} self.component_labels: Dict[str, List[List[Labels]]] = {} self.property_labels: Dict[str, List[Labels]] = {} @@ -151,13 +132,10 @@ def __init__(self, hypers: Dict, dataset_info: DatasetInfo) -> None: additive_models = [composition_model] self.additive_models = torch.nn.ModuleList(additive_models) - def _add_output(self, target_name: str, target: TargetInfo) -> None: + if not target.is_scalar: + raise ValueError("The 
DPA3 architecture can only predict scalars.") self.num_properties[target_name] = {} - ll_features_name = ( - f"mtt::aux::{target_name.replace('mtt::', '')}_last_layer_features" - ) - self.outputs[ll_features_name] = ModelOutput(per_atom=True) self.key_labels[target_name] = target.layout.keys self.component_labels[target_name] = [ block.components for block in target.layout.blocks() @@ -165,19 +143,18 @@ def _add_output(self, target_name: str, target: TargetInfo) -> None: self.property_labels[target_name] = [ block.properties for block in target.layout.blocks() ] - self.outputs[target_name] = ModelOutput( quantity=target.quantity, unit=target.unit, per_atom=True, ) - + def get_rcut(self): return self.model.atomic_model.get_rcut() - + def get_sel(self): return self.model.atomic_model.get_sel() - + def requested_neighbor_lists(self) -> List[NeighborListOptions]: return [self.requested_nl] @@ -188,19 +165,8 @@ def forward( selected_atoms: Optional[Labels] = None, ) -> Dict[str, TensorMap]: device = systems[0].positions.device - - atype_dtype = systems[0].types.dtype - system_indices = torch.concatenate( - [ - torch.full( - (len(system),), - i_system, - device=device, - ) - for i_system, system in enumerate(systems) - ], - ) + atype_dtype = systems[0].types.dtype if self.single_label.values.device != device: self.single_label = self.single_label.to(device) @@ -222,45 +188,26 @@ def forward( return_dict: Dict[str, TensorMap] = {} - sample_values = torch.stack( - [ - system_indices, - torch.concatenate( - [ - torch.arange( - len(system), - device=device, - ) - for system in systems - ], - ), - ], - dim=1, + (positions, species, cells, atom_index, system_index) = concatenate_structures( + systems ) - ( - positions, - species, - cells, - atom_index, - system_index - ) = concatenate_structures(systems) - - - type_to_index = {atomic_type: idx for idx, atomic_type in enumerate(self.atomic_types)} - type_to_index[-1] = -1 + type_to_index = { + atomic_type: idx for idx, atomic_type in enumerate(self.atomic_types) + } + type_to_index[-1] = -1 atype = torch.tensor( [[type_to_index[s.item()] for s in row] for row in species], - dtype=atype_dtype + dtype=atype_dtype, ).to(positions.device) atype = atype.to(atype_dtype) - + if torch.all(cells == 0).item(): box = None else: box = cells - + model_ret = self.model.forward_common( positions, atype, @@ -269,7 +216,7 @@ def forward( aparam=None, do_atomic_virial=False, ) - + if self.model.get_fitting_net() is not None: model_predict = {} model_predict["atom_energy"] = model_ret["energy"] @@ -284,32 +231,34 @@ def forward( else: model_predict = model_ret model_predict["updated_coord"] += positions - + atomic_properties: Dict[str, TensorMap] = {} blocks: List[TensorBlock] = [] - + system_col = system_index atom_col = atom_index - + values = torch.stack([system_col, atom_col], dim=0).transpose(0, 1) invariant_coefficients = Labels( - names=["system", "atom"], - values=values.to(device) + names=["system", "atom"], values=values.to(device) ) - mask = torch.abs(model_predict["atom_energy"]) > 1e-10 atomic_property_tensor = model_predict["atom_energy"][mask].unsqueeze(-1) - - blocks.append(TensorBlock( - values=atomic_property_tensor, - samples=invariant_coefficients, - components=self.component_labels[self.targets_keys][0], - properties=self.property_labels[self.targets_keys][0].to(device), - )) - - atomic_properties[self.targets_keys] = TensorMap(self.key_labels[self.targets_keys].to(device), blocks) - + + blocks.append( + TensorBlock( + values=atomic_property_tensor, 
+                samples=invariant_coefficients,
+                components=self.component_labels[self.targets_keys][0],
+                properties=self.property_labels[self.targets_keys][0].to(device),
+            )
+        )
+
+        atomic_properties[self.targets_keys] = TensorMap(
+            self.key_labels[self.targets_keys].to(device), blocks
+        )
+
         if selected_atoms is not None:
             for output_name, tmap in atomic_properties.items():
                 atomic_properties[output_name] = mts.slice(
@@ -317,13 +266,12 @@ def forward(
             )
 
         for output_name, atomic_property in atomic_properties.items():
-
             if outputs[output_name].per_atom:
                 return_dict[output_name] = atomic_property
             else:
                 # sum the atomic property to get the total property
                 return_dict[output_name] = sum_over_atoms(atomic_property)
-        
+
         if not self.training:
             # at evaluation, we also introduce the scaler and additive contributions
             return_dict = self.scaler(return_dict)
@@ -342,9 +290,8 @@ def forward(
                         return_dict[name],
                         additive_contributions[name],
                     )
-        
+
         return return_dict
-
 
     def restart(self, dataset_info: DatasetInfo) -> "DPA3":
         # merge old and new dataset info
@@ -432,10 +379,10 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
 
         # Additionally, the composition model contains some `TensorMap`s that cannot
         # be registered correctly with PyTorch. This function moves them:
-        
+
         self.additive_models[0].weights_to(torch.device("cpu"), torch.float64)
 
-        interaction_ranges = [self.hypers['descriptor']['repflow']['e_rcut']]
+        interaction_ranges = [self.hypers["descriptor"]["repflow"]["e_rcut"]]
         for additive_model in self.additive_models:
             if hasattr(additive_model, "cutoff_radius"):
                 interaction_ranges.append(additive_model.cutoff_radius)
@@ -458,19 +405,9 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
 
     @classmethod
     def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict:
-        if checkpoint["model_ckpt_version"] == 1:
-            checkpoints.update_v1_v2(checkpoint["model_state_dict"])
-            checkpoints.update_v1_v2(checkpoint["best_model_state_dict"])
-            checkpoint["model_ckpt_version"] = 2
-
-        if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__:
-            raise RuntimeError(
-                f"Unable to upgrade the checkpoint: the checkpoint is using model "
-                f"version {checkpoint['model_ckpt_version']}, while the current "
-                f"model version is {cls.__checkpoint_version__}."
- ) + # version is still one, there are no new versions return checkpoint - + def get_checkpoint(self) -> Dict: checkpoint = { "architecture_name": "experimental.dpa3", @@ -484,6 +421,6 @@ def get_checkpoint(self) -> Dict: "best_model_state_dict": None, } return checkpoint - + def supported_outputs(self) -> Dict[str, ModelOutput]: - return self.outputs \ No newline at end of file + return self.outputs diff --git a/src/metatrain/experimental/dpa3/schema-hypers.json b/src/metatrain/experimental/dpa3/schema-hypers.json index 7108361de5..d9a9f170cb 100644 --- a/src/metatrain/experimental/dpa3/schema-hypers.json +++ b/src/metatrain/experimental/dpa3/schema-hypers.json @@ -4,7 +4,9 @@ "properties": { "name": { "type": "string", - "enum": ["experimental.dpa3"] + "enum": [ + "experimental.dpa3" + ] }, "model": { "type": "object", @@ -21,37 +23,92 @@ "repflow": { "type": "object", "properties": { - "n_dim": { "type": "integer" }, - "e_dim": { "type": "integer" }, - "a_dim": { "type": "integer" }, - "nlayers": { "type": "integer" }, - "e_rcut": { "type": "number" }, - "e_rcut_smth": { "type": "number" }, - "e_sel": { "type": "integer" }, - "a_rcut": { "type": "number" }, - "a_rcut_smth": { "type": "number" }, - "a_sel": { "type": "integer" }, - "axis_neuron": { "type": "integer" }, - "skip_stat": { "type": "boolean" }, - "a_compress_rate": { "type": "number" }, - "a_compress_e_rate": { "type": "number" }, - "a_compress_use_split": { "type": "boolean" }, - "update_angle": { "type": "boolean" }, - "update_style": { "type": "string" }, - "update_residual": { "type": "number" }, - "update_residual_init": { "type": "string" }, - "smooth_edge_update": { "type": "boolean" }, - "use_dynamic_sel": { "type": "boolean" }, - "sel_reduce_factor": { "type": "number" } + "n_dim": { + "type": "integer" + }, + "e_dim": { + "type": "integer" + }, + "a_dim": { + "type": "integer" + }, + "nlayers": { + "type": "integer" + }, + "e_rcut": { + "type": "number" + }, + "e_rcut_smth": { + "type": "number" + }, + "e_sel": { + "type": "integer" + }, + "a_rcut": { + "type": "number" + }, + "a_rcut_smth": { + "type": "number" + }, + "a_sel": { + "type": "integer" + }, + "axis_neuron": { + "type": "integer" + }, + "skip_stat": { + "type": "boolean" + }, + "a_compress_rate": { + "type": "number" + }, + "a_compress_e_rate": { + "type": "number" + }, + "a_compress_use_split": { + "type": "boolean" + }, + "update_angle": { + "type": "boolean" + }, + "update_style": { + "type": "string" + }, + "update_residual": { + "type": "number" + }, + "update_residual_init": { + "type": "string" + }, + "smooth_edge_update": { + "type": "boolean" + }, + "use_dynamic_sel": { + "type": "boolean" + }, + "sel_reduce_factor": { + "type": "number" + } }, "required": [ - "n_dim", "e_dim", "a_dim", "nlayers" + "n_dim", + "e_dim", + "a_dim", + "nlayers" ] }, - "activation_function": { "type": "string" }, - "use_tebd_bias": { "type": "boolean" }, - "precision": { "type": "string" }, - "concat_output_tebd": { "type": "boolean" } + "activation_function": { + "type": "string" + }, + "use_tebd_bias": { + "type": "boolean" + }, + "precision": { + "type": "string" + }, + "concat_output_tebd": { + "type": "boolean" + } } }, "fitting_net": { @@ -59,43 +116,96 @@ "properties": { "neuron": { "type": "array", - "items": { "type": "integer" } - }, - "resnet_dt": { "type": "boolean" }, - "seed": { "type": "integer" }, - "precision": { "type": "string" }, - "activation_function": { "type": "string" }, - "type": { "type": "string" }, - "numb_fparam": { "type": "integer" 
}, - "numb_aparam": { "type": "integer" }, - "dim_case_embd": { "type": "integer" }, - "trainable": { "type": "boolean" }, - "rcond": { "type": ["number", "null"] }, + "items": { + "type": "integer" + } + }, + "resnet_dt": { + "type": "boolean" + }, + "seed": { + "type": "integer" + }, + "precision": { + "type": "string" + }, + "activation_function": { + "type": "string" + }, + "type": { + "type": "string" + }, + "numb_fparam": { + "type": "integer" + }, + "numb_aparam": { + "type": "integer" + }, + "dim_case_embd": { + "type": "integer" + }, + "trainable": { + "type": "boolean" + }, + "rcond": { + "type": [ + "number", + "null" + ] + }, "atom_ener": { "type": "array", - "items": { "type": "number" } + "items": { + "type": "number" + } }, - "use_aparam_as_mask": { "type": "boolean" } + "use_aparam_as_mask": { + "type": "boolean" + } } }, "cutoff_width": { "type": "object", "properties": { - "d_pet": { "type": "integer" }, - "d_head": { "type": "integer" }, - "d_feedforward": { "type": "integer" }, - "num_heads": { "type": "integer" }, - "num_attention_layers": { "type": "integer" }, - "num_gnn_layers": { "type": "integer" }, - "zbl": { "type": "boolean" }, + "d_pet": { + "type": "integer" + }, + "d_head": { + "type": "integer" + }, + "d_feedforward": { + "type": "integer" + }, + "num_heads": { + "type": "integer" + }, + "num_attention_layers": { + "type": "integer" + }, + "num_gnn_layers": { + "type": "integer" + }, + "zbl": { + "type": "boolean" + }, "long_range": { "type": "object", "properties": { - "enable": { "type": "boolean" }, - "use_ewald": { "type": "boolean" }, - "smearing": { "type": "number" }, - "kspace_resolution": { "type": "number" }, - "interpolation_nodes": { "type": "integer" } + "enable": { + "type": "boolean" + }, + "use_ewald": { + "type": "boolean" + }, + "smearing": { + "type": "number" + }, + "kspace_resolution": { + "type": "number" + }, + "interpolation_nodes": { + "type": "integer" + } } } } @@ -106,17 +216,39 @@ "training": { "type": "object", "properties": { - "distributed": { "type": "boolean" }, - "distributed_port": { "type": "integer" }, - "batch_size": { "type": "integer" }, - "num_epochs": { "type": "integer" }, - "learning_rate": { "type": "number" }, - "early_stopping_patience": { "type": "integer" }, - "scheduler_patience": { "type": "integer" }, - "scheduler_factor": { "type": "number" }, - "log_interval": { "type": "integer" }, - "checkpoint_interval": { "type": "integer" }, - "scale_targets": { "type": "boolean" }, + "distributed": { + "type": "boolean" + }, + "distributed_port": { + "type": "integer" + }, + "batch_size": { + "type": "integer" + }, + "num_epochs": { + "type": "integer" + }, + "learning_rate": { + "type": "number" + }, + "early_stopping_patience": { + "type": "integer" + }, + "scheduler_patience": { + "type": "integer" + }, + "scheduler_factor": { + "type": "number" + }, + "log_interval": { + "type": "integer" + }, + "checkpoint_interval": { + "type": "integer" + }, + "scale_targets": { + "type": "boolean" + }, "fixed_composition_weights": { "type": "object", "patternProperties": { @@ -125,25 +257,39 @@ "propertyNames": { "pattern": "^[0-9]+$" }, - "additionalProperties": { "type": "number" } + "additionalProperties": { + "type": "number" + } } }, "additionalProperties": false }, "per_structure_targets": { "type": "array", - "items": { "type": "string" } + "items": { + "type": "string" + } + }, + "log_mae": { + "type": "boolean" + }, + "log_separate_blocks": { + "type": "boolean" }, - "log_mae": { "type": "boolean" }, - 
"log_separate_blocks": { "type": "boolean" }, "best_model_metric": { "type": "string", - "enum": ["rmse_prod", "mae_prod", "loss"] + "enum": [ + "rmse_prod", + "mae_prod", + "loss" + ] + }, + "loss": { + "type": "object" }, - "loss": { "type": "object" }, "additionalProperties": false } } }, "additionalProperties": false -} \ No newline at end of file +} diff --git a/src/metatrain/experimental/dpa3/tests/test_checkpoints.py b/src/metatrain/experimental/dpa3/tests/test_checkpoints.py index 5c60934413..5cdaf5681b 100644 --- a/src/metatrain/experimental/dpa3/tests/test_checkpoints.py +++ b/src/metatrain/experimental/dpa3/tests/test_checkpoints.py @@ -17,6 +17,7 @@ from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS + # from pathlib import Path # from metatrain.utils.architectures import get_default_hypers @@ -66,7 +67,6 @@ def model_trainer(): # minimize the size of the checkpoint on disk hypers = copy.deepcopy(MODEL_HYPERS) - model = DPA3(hypers, dataset_info) @@ -126,9 +126,10 @@ def test_failed_checkpoint_upgrade(cls_type): with pytest.raises(RuntimeError, match=match): cls.upgrade_checkpoint(checkpoint) + if __name__ == "__main__": test_get_checkpoint("finetune") test_get_checkpoint("restart") test_get_checkpoint("export") test_failed_checkpoint_upgrade("model") - test_failed_checkpoint_upgrade("trainer") \ No newline at end of file + test_failed_checkpoint_upgrade("trainer") diff --git a/src/metatrain/experimental/dpa3/tests/test_continue.py b/src/metatrain/experimental/dpa3/tests/test_continue.py index a34c2050b6..a58e8514a4 100644 --- a/src/metatrain/experimental/dpa3/tests/test_continue.py +++ b/src/metatrain/experimental/dpa3/tests/test_continue.py @@ -16,6 +16,7 @@ from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS + def test_continue(monkeypatch, tmp_path): """Tests that a model can be checkpointed and loaded for a continuation of the training process""" @@ -100,7 +101,3 @@ def test_continue(monkeypatch, tmp_path): output_after = model_after(systems[:5], {"mtt::U0": model_after.outputs["mtt::U0"]}) assert metatensor.torch.allclose(output_before["mtt::U0"], output_after["mtt::U0"]) - -if __name__ == "__main__": - tmp_path = "/aisi/mnt/data_nas/liwentao/devel_workspace/metatrain/metatrain/src/metatrain/dpa3/tests/tmp" - test_continue(monkeypatch, tmp_path) \ No newline at end of file diff --git a/src/metatrain/experimental/dpa3/tests/test_functionality.py b/src/metatrain/experimental/dpa3/tests/test_functionality.py index 142f075035..5c39e41ab3 100644 --- a/src/metatrain/experimental/dpa3/tests/test_functionality.py +++ b/src/metatrain/experimental/dpa3/tests/test_functionality.py @@ -1,7 +1,5 @@ import metatensor.torch as mts -import pytest import torch -from jsonschema.exceptions import ValidationError from metatomic.torch import ModelOutput, System from omegaconf import OmegaConf @@ -10,7 +8,6 @@ from metatrain.utils.data import DatasetInfo from metatrain.utils.data.target_info import ( get_energy_target_info, - get_generic_target_info, ) from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists @@ -190,7 +187,7 @@ def test_prediction_subset_atoms(): {"energy": ModelOutput(per_atom=False)}, selected_atoms=selection_labels, ) - + assert not mts.allclose(energy_monomer["energy"], energy_dimer["energy"]) assert mts.allclose(energy_monomer["energy"], energy_monomer_in_dimer["energy"]) @@ -243,9 +240,9 @@ def test_fixed_composition_weights(): } } hypers = OmegaConf.create(hypers) - check_architecture_options(name="experimental.dpa3", 
options=OmegaConf.to_container(hypers)) - - + check_architecture_options( + name="experimental.dpa3", options=OmegaConf.to_container(hypers) + ) def test_pet_single_atom(): @@ -269,5 +266,3 @@ def test_pet_single_atom(): system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) outputs = {"energy": ModelOutput(per_atom=False)} model([system], outputs) - - diff --git a/src/metatrain/experimental/dpa3/tests/test_regression.py b/src/metatrain/experimental/dpa3/tests/test_regression.py index 4c45795d82..b5ac870a14 100644 --- a/src/metatrain/experimental/dpa3/tests/test_regression.py +++ b/src/metatrain/experimental/dpa3/tests/test_regression.py @@ -30,7 +30,7 @@ def test_regression_init(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets ) - model = DPA3(MODEL_HYPERS, dataset_info).to('cpu') + model = DPA3(MODEL_HYPERS, dataset_info).to("cpu") # Predict on the first five systems systems = read_systems(DATASET_PATH)[:5] @@ -42,7 +42,7 @@ def test_regression_init(): systems, {"mtt::U0": ModelOutput(quantity="energy", unit="", per_atom=False)}, ) - + expected_output = torch.tensor( [ [-8.196924407132], @@ -50,13 +50,14 @@ def test_regression_init(): [-4.913442698461], [-6.568228343430], [-4.895156818840], - ],dtype=torch.float64 + ], + dtype=torch.float64, ) # if you need to change the hardcoded values: # torch.set_printoptions(precision=12) # print(output["mtt::U0"].block().values) - + torch.testing.assert_close(output["mtt::U0"].block().values, expected_output) @@ -95,7 +96,7 @@ def test_regression_energies_forces_train(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[6], targets=target_info_dict ) - model = DPA3(MODEL_HYPERS, dataset_info).to('cpu') + model = DPA3(MODEL_HYPERS, dataset_info).to("cpu") trainer = Trainer(hypers["training"]) trainer.train( model=model, @@ -114,19 +115,20 @@ def test_regression_energies_forces_train(): output = evaluate_model( model, systems[:5], targets=target_info_dict, is_training=False ) - + expected_output = torch.tensor( [ [-0.555318677882], [-0.569078342763], [-0.579769296313], [-0.518369165620], - [-0.556493731493] - ],dtype=torch.float64 + [-0.556493731493], + ], + dtype=torch.float64, ) expected_gradients_output = torch.tensor( - [-0.006725381569, 0.008463345547, 0.025475740380],dtype=torch.float64 + [-0.006725381569, 0.008463345547, 0.025475740380], dtype=torch.float64 ) # if you need to change the hardcoded values: diff --git a/src/metatrain/experimental/dpa3/tests/test_torchscript.py b/src/metatrain/experimental/dpa3/tests/test_torchscript.py index c0b51cacc5..4520193c1b 100644 --- a/src/metatrain/experimental/dpa3/tests/test_torchscript.py +++ b/src/metatrain/experimental/dpa3/tests/test_torchscript.py @@ -21,7 +21,7 @@ def test_torchscript(): "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"}) }, ) - + model = DPA3(MODEL_HYPERS, dataset_info).to("cpu") system = System( types=torch.tensor([6, 1, 8, 7]), @@ -32,7 +32,7 @@ def test_torchscript(): pbc=torch.tensor([False, False, False]), ) system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) - + model = torch.jit.script(model) model( [system], diff --git a/src/metatrain/experimental/dpa3/trainer.py b/src/metatrain/experimental/dpa3/trainer.py index 51afc9ff0b..64ef8b9dca 100644 --- a/src/metatrain/experimental/dpa3/trainer.py +++ b/src/metatrain/experimental/dpa3/trainer.py @@ -61,7 +61,7 @@ def train( checkpoint_dir: str, ): assert dtype in DPA3.__supported_dtypes__ 
- + is_distributed = self.hypers["distributed"] if is_distributed: @@ -124,9 +124,9 @@ def train( # numerical errors in the composition weights, which can be very large). for additive_model in model.additive_models: additive_model.to(dtype=torch.float64) - + logging.info("Calculating composition weights") - + model.additive_models[0].train_model( # this is the composition model train_datasets, model.additive_models[1:], @@ -139,7 +139,7 @@ def train( logging.info("Calculating scaling weights") model.scaler.train_model( train_datasets, model.additive_models, treat_as_additive=True - ) + ) if is_distributed: model = DistributedDataParallel(model, device_ids=[device]) @@ -226,7 +226,7 @@ def train( ) ) val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False) - + # Extract all the possible outputs and their gradients: train_targets = (model.module if is_distributed else model).dataset_info.targets outputs_list = [] @@ -314,7 +314,7 @@ def train( optimizer.zero_grad() model.to(device) systems, targets, extra_data = batch - + systems, targets, extra_data = batch_to( systems, targets, extra_data, device=device ) @@ -330,7 +330,7 @@ def train( systems, targets, extra_data = batch_to( systems, targets, extra_data, dtype=dtype ) - + predictions = evaluate_model( model, systems, @@ -342,8 +342,7 @@ def train( predictions, systems, per_structure_targets ) targets = average_by_num_atoms(targets, systems, per_structure_targets) - - + train_loss_batch = loss_fn(predictions, targets) train_loss_batch.backward() @@ -356,7 +355,7 @@ def train( train_rmse_calculator.update(predictions, targets) if self.hypers["log_mae"]: train_mae_calculator.update(predictions, targets) - + finalized_train_info = train_rmse_calculator.finalize( not_per_atom=["positions_gradients"] + per_structure_targets, is_distributed=is_distributed, @@ -564,4 +563,4 @@ def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: f"version {checkpoint['trainer_ckpt_version']}, while the current " f"trainer version is {cls.__checkpoint_version__}." 
) - return checkpoint \ No newline at end of file + return checkpoint diff --git a/tests/resources/options.yaml b/tests/resources/options.yaml index 3b01066f70..f176a29973 100644 --- a/tests/resources/options.yaml +++ b/tests/resources/options.yaml @@ -1,14 +1,7 @@ seed: 42 architecture: - name: soap_bpnn - training: - batch_size: 5 - num_epochs: 1 - model: - soap: - max_radial: 4 - max_angular: 2 + name: experimental.dpa3 training_set: systems: diff --git a/tests/utils/test_architectures.py b/tests/utils/test_architectures.py index 732d4f1e91..6d04d4ca08 100644 --- a/tests/utils/test_architectures.py +++ b/tests/utils/test_architectures.py @@ -28,6 +28,7 @@ def test_find_all_architectures(): assert "pet" in all_arches assert "soap_bpnn" in all_arches assert "experimental.nanopet" in all_arches + assert "experimental.dpa3" in all_arches assert "deprecated.pet" in all_arches assert "llpr" in all_arches From 5204317376733541a0903e5fed7fa938ddfbef81 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 26 Sep 2025 19:49:53 +0200 Subject: [PATCH 08/18] Fix tests --- .../dpa3/tests/test_checkpoints.py | 135 ------------------ .../experimental/dpa3/tests/test_continue.py | 4 +- .../dpa3/tests/test_functionality.py | 4 +- .../dpa3/tests/test_regression.py | 2 +- 4 files changed, 6 insertions(+), 139 deletions(-) delete mode 100644 src/metatrain/experimental/dpa3/tests/test_checkpoints.py diff --git a/src/metatrain/experimental/dpa3/tests/test_checkpoints.py b/src/metatrain/experimental/dpa3/tests/test_checkpoints.py deleted file mode 100644 index 5cdaf5681b..0000000000 --- a/src/metatrain/experimental/dpa3/tests/test_checkpoints.py +++ /dev/null @@ -1,135 +0,0 @@ -import copy - -import pytest -import torch - -from metatrain.experimental.dpa3 import DPA3, Trainer -from metatrain.utils.data import ( - DatasetInfo, - get_atomic_types, - get_dataset, -) -from metatrain.utils.data.target_info import get_energy_target_info -from metatrain.utils.testing.checkpoints import ( - checkpoint_did_not_change, - make_checkpoint_load_tests, -) - -from . 
import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS - - -# from pathlib import Path - -# from metatrain.utils.architectures import get_default_hypers - -# DATASET_PATH = str(Path(__file__).parents[4] / "tests/resources/qm9_reduced_100.xyz") -# DATASET_WITH_FORCES_PATH = str( -# Path(__file__).parents[4] / "tests/resources/carbon_reduced_100.xyz" -# ) - -# DEFAULT_HYPERS = get_default_hypers("dpa3") -# MODEL_HYPERS = DEFAULT_HYPERS["model"] - - -@pytest.fixture(scope="module") -def model_trainer(): - energy_target = { - "quantity": "energy", - "read_from": DATASET_PATH, - "reader": "ase", - "key": "U0", - "unit": "eV", - "type": "scalar", - "per_atom": False, - "num_subtargets": 1, - "forces": False, - "stress": False, - "virial": False, - } - - dataset, targets_info, _ = get_dataset( - { - "systems": { - "read_from": DATASET_PATH, - "reader": "ase", - }, - "targets": { - "energy": energy_target, - }, - } - ) - - dataset_info = DatasetInfo( - length_unit="", - atomic_types=get_atomic_types(dataset), - targets=targets_info, - ) - - # minimize the size of the checkpoint on disk - hypers = copy.deepcopy(MODEL_HYPERS) - - model = DPA3(hypers, dataset_info) - - hypers = copy.deepcopy(DEFAULT_HYPERS) - hypers["training"]["num_epochs"] = 1 - trainer = Trainer(hypers["training"]) - - trainer.train( - model, - dtype=model.__supported_dtypes__[0], - devices=[torch.device("cpu")], - train_datasets=[dataset], - val_datasets=[dataset], - checkpoint_dir="", - ) - - return model, trainer - - -test_checkpoint_did_not_change = checkpoint_did_not_change - -test_loading_old_checkpoints = make_checkpoint_load_tests(DEFAULT_HYPERS) - - -@pytest.mark.parametrize("context", ["finetune", "restart", "export"]) -def test_get_checkpoint(context): - """ - Test that the checkpoint created by the model.get_checkpoint() - function can be loaded back in all possible contexts. - """ - dataset_info = DatasetInfo( - length_unit="Angstrom", - atomic_types=[1, 6, 7, 8], - targets={"energy": get_energy_target_info({"unit": "eV"})}, - ) - model = DPA3(MODEL_HYPERS, dataset_info) - checkpoint = model.get_checkpoint() - DPA3.load_checkpoint(checkpoint, context) - - -@pytest.mark.parametrize("cls_type", ["model", "trainer"]) -def test_failed_checkpoint_upgrade(cls_type): - """Test error raised when trying to upgrade an invalid checkpoint version.""" - checkpoint = {f"{cls_type}_ckpt_version": 9999} - - if cls_type == "model": - cls = DPA3 - version = DPA3.__checkpoint_version__ - else: - cls = Trainer - version = Trainer.__checkpoint_version__ - - match = ( - f"Unable to upgrade the checkpoint: the checkpoint is using {cls_type} version " - f"9999, while the current {cls_type} version is {version}." 
- ) - with pytest.raises(RuntimeError, match=match): - cls.upgrade_checkpoint(checkpoint) - - -if __name__ == "__main__": - test_get_checkpoint("finetune") - test_get_checkpoint("restart") - test_get_checkpoint("export") - test_failed_checkpoint_upgrade("model") - test_failed_checkpoint_upgrade("trainer") diff --git a/src/metatrain/experimental/dpa3/tests/test_continue.py b/src/metatrain/experimental/dpa3/tests/test_continue.py index a58e8514a4..3c0616997d 100644 --- a/src/metatrain/experimental/dpa3/tests/test_continue.py +++ b/src/metatrain/experimental/dpa3/tests/test_continue.py @@ -67,7 +67,7 @@ def test_continue(monkeypatch, tmp_path): trainer = Trainer(hypers["training"]) trainer.train( model=model, - dtype=torch.float64, + dtype=torch.float32, devices=[torch.device("cpu")], train_datasets=[dataset], val_datasets=[dataset], @@ -83,7 +83,7 @@ def test_continue(monkeypatch, tmp_path): trainer = Trainer(hypers["training"]) trainer.train( model=model_after, - dtype=torch.float64, + dtype=torch.float32, devices=[torch.device("cpu")], train_datasets=[dataset], val_datasets=[dataset], diff --git a/src/metatrain/experimental/dpa3/tests/test_functionality.py b/src/metatrain/experimental/dpa3/tests/test_functionality.py index 5c39e41ab3..8ec99deb87 100644 --- a/src/metatrain/experimental/dpa3/tests/test_functionality.py +++ b/src/metatrain/experimental/dpa3/tests/test_functionality.py @@ -190,7 +190,9 @@ def test_prediction_subset_atoms(): assert not mts.allclose(energy_monomer["energy"], energy_dimer["energy"]) - assert mts.allclose(energy_monomer["energy"], energy_monomer_in_dimer["energy"]) + assert mts.allclose( + energy_monomer["energy"], energy_monomer_in_dimer["energy"], atol=1e-6 + ) torch.set_default_dtype(default_dtype_before) diff --git a/src/metatrain/experimental/dpa3/tests/test_regression.py b/src/metatrain/experimental/dpa3/tests/test_regression.py index b5ac870a14..45764e59bc 100644 --- a/src/metatrain/experimental/dpa3/tests/test_regression.py +++ b/src/metatrain/experimental/dpa3/tests/test_regression.py @@ -100,7 +100,7 @@ def test_regression_energies_forces_train(): trainer = Trainer(hypers["training"]) trainer.train( model=model, - dtype=torch.float64, + dtype=torch.float32, devices=[torch.device("cpu")], train_datasets=[dataset], val_datasets=[dataset], From c536d42827597475dcdc476a380a10e9106d2d08 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 26 Sep 2025 20:01:11 +0200 Subject: [PATCH 09/18] Update to new loss code --- src/metatrain/experimental/dpa3/trainer.py | 122 ++++++++++----------- 1 file changed, 61 insertions(+), 61 deletions(-) diff --git a/src/metatrain/experimental/dpa3/trainer.py b/src/metatrain/experimental/dpa3/trainer.py index 64ef8b9dca..a31d69440f 100644 --- a/src/metatrain/experimental/dpa3/trainer.py +++ b/src/metatrain/experimental/dpa3/trainer.py @@ -20,10 +20,9 @@ ) from metatrain.utils.distributed.slurm import DistributedEnvironment from metatrain.utils.evaluate_model import evaluate_model -from metatrain.utils.external_naming import to_external_name from metatrain.utils.io import check_file_extension from metatrain.utils.logging import ROOT_LOGGER, MetricLogger -from metatrain.utils.loss import TensorMapDictLoss +from metatrain.utils.loss import LossAggregator from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, @@ -47,6 +46,7 @@ def __init__(self, hypers): self.optimizer_state_dict = None self.scheduler_state_dict = None 
self.epoch = None + self.best_epoch = None self.best_metric = None self.best_model_state_dict = None self.best_optimizer_state_dict = None @@ -76,7 +76,7 @@ def train( if len(devices) > 1: raise ValueError( "Requested distributed training with the `multi-gpu` device. " - " If you want to run distributed training with SOAP-BPNN, please " + " If you want to run distributed training with DPA3, please " "set `device` to cuda." ) # the calculation of the device number works both when GPUs on different @@ -126,7 +126,6 @@ def train( additive_model.to(dtype=torch.float64) logging.info("Calculating composition weights") - model.additive_models[0].train_model( # this is the composition model train_datasets, model.additive_models[1:], @@ -203,7 +202,7 @@ def train( collate_fn=collate_fn, ) ) - train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=False) + train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True) # Create dataloader for the validation datasets: val_dataloaders = [] @@ -234,29 +233,23 @@ def train( outputs_list.append(target_name) for gradient_name in target_info.gradients: outputs_list.append(f"{target_name}_{gradient_name}_gradients") - # Create a loss weight dict: - loss_weights_dict = {} - for output_name in outputs_list: - loss_weights_dict[output_name] = ( - self.hypers["loss"]["weights"][ - to_external_name(output_name, train_targets) - ] - if to_external_name(output_name, train_targets) - in self.hypers["loss"]["weights"] - else 1.0 - ) - loss_weights_dict_external = { - to_external_name(key, train_targets): value - for key, value in loss_weights_dict.items() - } - loss_hypers = copy.deepcopy(self.hypers["loss"]) - loss_hypers["weights"] = loss_weights_dict - logging.info(f"Training with loss weights: {loss_weights_dict_external}") # Create a loss function: - loss_fn = TensorMapDictLoss( - **loss_hypers, + loss_hypers = self.hypers["loss"] + loss_fn = LossAggregator( + targets=train_targets, + config=loss_hypers, ) + logging.info("Using the following loss functions:") + for name, info in loss_fn.metadata.items(): + logging.info(f"{name}:") + main = {k: v for k, v in info.items() if k != "gradients"} + logging.info(main) + if "gradients" not in info or len(info["gradients"]) == 0: + continue + logging.info("With gradients:") + for grad, ginfo in info["gradients"].items(): + logging.info(f"\t{name}::{grad}: {ginfo}") # Create an optimizer: optimizer = torch.optim.Adam( @@ -312,9 +305,8 @@ def train( for batch in train_dataloader: optimizer.zero_grad() - model.to(device) - systems, targets, extra_data = batch + systems, targets, extra_data = batch systems, targets, extra_data = batch_to( systems, targets, extra_data, device=device ) @@ -337,13 +329,14 @@ def train( {key: train_targets[key] for key in targets.keys()}, is_training=True, ) + # average by the number of atoms predictions = average_by_num_atoms( predictions, systems, per_structure_targets ) targets = average_by_num_atoms(targets, systems, per_structure_targets) - train_loss_batch = loss_fn(predictions, targets) + train_loss_batch = loss_fn(predictions, targets, extra_data) train_loss_batch.backward() optimizer.step() @@ -402,7 +395,7 @@ def train( ) targets = average_by_num_atoms(targets, systems, per_structure_targets) - val_loss_batch = loss_fn(predictions, targets) + val_loss_batch = loss_fn(predictions, targets, extra_data) if is_distributed: # sum the loss over all processes @@ -487,6 +480,7 @@ def train( self.best_model_state_dict = copy.deepcopy( (model.module if is_distributed else 
model).state_dict() ) + self.best_epoch = epoch self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) if epoch % self.hypers["checkpoint_interval"] == 0: @@ -505,26 +499,39 @@ def train( self.epoch = epoch self.optimizer_state_dict = optimizer.state_dict() self.scheduler_state_dict = lr_scheduler.state_dict() + checkpoint = model.get_checkpoint() + checkpoint.update( + { + "train_hypers": self.hypers, + "trainer_ckpt_version": self.__checkpoint_version__, + "epoch": self.epoch, + "optimizer_state_dict": self.optimizer_state_dict, + "scheduler_state_dict": self.scheduler_state_dict, + "best_epoch": self.best_epoch, + "best_metric": self.best_metric, + "best_model_state_dict": self.best_model_state_dict, + "best_optimizer_state_dict": self.best_optimizer_state_dict, + } + ) + + if is_distributed: + torch.distributed.destroy_process_group() def save_checkpoint(self, model, path: Union[str, Path]): - checkpoint = { - "architecture_name": "experimental.dpa3", - "model_ckpt_version": model.__checkpoint_version__, - "trainer_ckpt_version": self.__checkpoint_version__, - "metadata": model.__default_metadata__, - "model_data": { - "model_hypers": model.hypers, - "dataset_info": model.dataset_info, - }, - "model_state_dict": model.state_dict(), - "train_hypers": self.hypers, - "epoch": self.epoch, - "optimizer_state_dict": self.optimizer_state_dict, - "scheduler_state_dict": self.scheduler_state_dict, - "best_metric": self.best_metric, - "best_model_state_dict": self.best_model_state_dict, - "best_optimizer_state_dict": self.best_optimizer_state_dict, - } + checkpoint = model.get_checkpoint() + checkpoint.update( + { + "train_hypers": self.hypers, + "trainer_ckpt_version": self.__checkpoint_version__, + "epoch": self.epoch, + "optimizer_state_dict": self.optimizer_state_dict, + "scheduler_state_dict": self.scheduler_state_dict, + "best_epoch": self.best_epoch, + "best_metric": self.best_metric, + "best_model_state_dict": self.best_model_state_dict, + "best_optimizer_state_dict": self.best_optimizer_state_dict, + } + ) torch.save( checkpoint, check_file_extension(path, ".ckpt"), @@ -537,21 +544,14 @@ def load_checkpoint( hypers: Dict[str, Any], context: Literal["restart", "finetune"], # not used at the moment ) -> "Trainer": - epoch = checkpoint["epoch"] - optimizer_state_dict = checkpoint["optimizer_state_dict"] - scheduler_state_dict = checkpoint["scheduler_state_dict"] - best_metric = checkpoint["best_metric"] - best_model_state_dict = checkpoint["best_model_state_dict"] - best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"] - - # Create the trainer trainer = cls(hypers) - trainer.optimizer_state_dict = optimizer_state_dict - trainer.scheduler_state_dict = scheduler_state_dict - trainer.epoch = epoch - trainer.best_metric = best_metric - trainer.best_model_state_dict = best_model_state_dict - trainer.best_optimizer_state_dict = best_optimizer_state_dict + trainer.optimizer_state_dict = checkpoint["optimizer_state_dict"] + trainer.scheduler_state_dict = checkpoint["scheduler_state_dict"] + trainer.epoch = checkpoint["epoch"] + trainer.best_epoch = checkpoint["best_epoch"] + trainer.best_metric = checkpoint["best_metric"] + trainer.best_model_state_dict = checkpoint["best_model_state_dict"] + trainer.best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"] return trainer From c49063f0556297f0ebd815fbc4a8329419c68561 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 26 Sep 2025 20:06:03 +0200 Subject: [PATCH 10/18] Fix tests --- 
.../experimental/dpa3/tests/test_regression.py | 10 +++++----- tests/utils/test_architectures.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/metatrain/experimental/dpa3/tests/test_regression.py b/src/metatrain/experimental/dpa3/tests/test_regression.py index 45764e59bc..34211773a6 100644 --- a/src/metatrain/experimental/dpa3/tests/test_regression.py +++ b/src/metatrain/experimental/dpa3/tests/test_regression.py @@ -55,8 +55,8 @@ def test_regression_init(): ) # if you need to change the hardcoded values: - # torch.set_printoptions(precision=12) - # print(output["mtt::U0"].block().values) + torch.set_printoptions(precision=12) + print(output["mtt::U0"].block().values) torch.testing.assert_close(output["mtt::U0"].block().values, expected_output) @@ -132,9 +132,9 @@ def test_regression_energies_forces_train(): ) # if you need to change the hardcoded values: - # torch.set_printoptions(precision=12) - # print(output["energy"].block().values) - # print(output["energy"].block().gradient("positions").values.squeeze(-1)[0]) + torch.set_printoptions(precision=12) + print(output["energy"].block().values) + print(output["energy"].block().gradient("positions").values.squeeze(-1)[0]) torch.testing.assert_close(output["energy"].block().values, expected_output) torch.testing.assert_close( diff --git a/tests/utils/test_architectures.py b/tests/utils/test_architectures.py index 6d04d4ca08..b632e68244 100644 --- a/tests/utils/test_architectures.py +++ b/tests/utils/test_architectures.py @@ -22,7 +22,7 @@ def is_None(*args, **kwargs) -> None: def test_find_all_architectures(): all_arches = find_all_architectures() - assert len(all_arches) == 6 + assert len(all_arches) == 7 assert "gap" in all_arches assert "pet" in all_arches From e676aaa06c34e625dd9372238584f6fdacf678cc Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 26 Sep 2025 20:13:51 +0200 Subject: [PATCH 11/18] Fix tests --- .../dpa3/tests/test_regression.py | 22 +++++++++---------- tests/resources/options.yaml | 9 +++++++- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/src/metatrain/experimental/dpa3/tests/test_regression.py b/src/metatrain/experimental/dpa3/tests/test_regression.py index 34211773a6..20fb0ff6c7 100644 --- a/src/metatrain/experimental/dpa3/tests/test_regression.py +++ b/src/metatrain/experimental/dpa3/tests/test_regression.py @@ -45,11 +45,11 @@ def test_regression_init(): expected_output = torch.tensor( [ - [-8.196924407132], - [-6.523280256844], - [-4.913442698461], - [-6.568228343430], - [-4.895156818840], + [8.893970727921], + [7.150644659996], + [5.338875532150], + [7.145487308502], + [5.402073264122], ], dtype=torch.float64, ) @@ -118,17 +118,17 @@ def test_regression_energies_forces_train(): expected_output = torch.tensor( [ - [-0.555318677882], - [-0.569078342763], - [-0.579769296313], - [-0.518369165620], - [-0.556493731493], + [0.630174279213], + [0.653932452202], + [0.664113998413], + [0.590713620186], + [0.635889530182], ], dtype=torch.float64, ) expected_gradients_output = torch.tensor( - [-0.006725381569, 0.008463345547, 0.025475740380], dtype=torch.float64 + [0.006374867036, -0.008849388247, 0.030855362978], dtype=torch.float64 ) # if you need to change the hardcoded values: diff --git a/tests/resources/options.yaml b/tests/resources/options.yaml index f176a29973..3b01066f70 100644 --- a/tests/resources/options.yaml +++ b/tests/resources/options.yaml @@ -1,7 +1,14 @@ seed: 42 architecture: - name: experimental.dpa3 + name: soap_bpnn + training: + batch_size: 5 + 
num_epochs: 1 + model: + soap: + max_radial: 4 + max_angular: 2 training_set: systems: From 53b5044fb378a9cc11f81e83eada401ab4be9616 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 26 Sep 2025 20:30:42 +0200 Subject: [PATCH 12/18] Fix docs --- docs/src/architectures/dpa3.rst | 36 ++++++++++++++++++++++++++++++++ examples/basic_usage/run_dpa3.sh | 9 -------- 2 files changed, 36 insertions(+), 9 deletions(-) delete mode 100644 examples/basic_usage/run_dpa3.sh diff --git a/docs/src/architectures/dpa3.rst b/docs/src/architectures/dpa3.rst index 6591f14a52..2517ec8792 100644 --- a/docs/src/architectures/dpa3.rst +++ b/docs/src/architectures/dpa3.rst @@ -6,3 +6,39 @@ DPA3 (experimental) .. warning:: This is an **experimental architecture**. You should not use it for anything important. + +This is an interface to the DPA3 architecture described in https://arxiv.org/abs/2506.01686 +and implemented in deepmd-kit (https://github.com/deepmodeling/deepmd-kit). + +Installation +------------ + +To install the package, you can run the following command in the root +directory of the repository: + +.. code-block:: bash + + pip install metatrain[dpa3] + +This will install the package with the DPA3 dependencies. + + +Default Hyperparameters +----------------------- + +The default hyperparameters for the DPA3 architecture are: + +.. literalinclude:: ../../../src/metatrain/experimental/dpa3/default-hypers.yaml + :language: yaml + + +Tuning Hyperparameters +---------------------- + +@littlepeachs this is where you can tell users how to tune the parameters of the model +to obtain different speed/accuracy tradeoffs + +References +---------- + +.. footbibliography:: diff --git a/examples/basic_usage/run_dpa3.sh b/examples/basic_usage/run_dpa3.sh deleted file mode 100644 index 325168d77f..0000000000 --- a/examples/basic_usage/run_dpa3.sh +++ /dev/null @@ -1,9 +0,0 @@ -export METATENSOR_DEBUG_EXTENSIONS_LOADING=1 - -mtt train options.yaml - -package_dir=$(python -c "import site; print(site.getsitepackages()[0])") -cp $package_dir/deepmd/lib/*.so extensions/deepmd/lib/ -cp $package_dir/deepmd_kit.libs/*.so* extensions/deepmd/lib/ - -mtt eval model.pt eval.yaml -e extensions/ \ No newline at end of file From 94748938bd405ed3ee3374492576ad1192e66a3f Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 26 Sep 2025 20:36:33 +0200 Subject: [PATCH 13/18] Fix docs --- docs/src/architectures/dpa3.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/architectures/dpa3.rst b/docs/src/architectures/dpa3.rst index 2517ec8792..b0f999d6a1 100644 --- a/docs/src/architectures/dpa3.rst +++ b/docs/src/architectures/dpa3.rst @@ -1,4 +1,4 @@ -.. _architecture-nanopet: +.. 
_architecture-dpa3: DPA3 (experimental) ====================== From 10fb6cc8e868cff14434773b5f4b0ed00ded617a Mon Sep 17 00:00:00 2001 From: Guillaume Fraux Date: Wed, 19 Nov 2025 15:43:05 +0100 Subject: [PATCH 14/18] Update dpa3 to follow with recent changes --- pyproject.toml | 3 +- .../experimental/dpa3/default-hypers.yaml | 68 ---- .../experimental/dpa3/documentation.py | 146 +++++++++ src/metatrain/experimental/dpa3/model.py | 8 +- .../experimental/dpa3/schema-hypers.json | 295 ------------------ src/metatrain/experimental/dpa3/trainer.py | 37 +-- 6 files changed, 169 insertions(+), 388 deletions(-) delete mode 100644 src/metatrain/experimental/dpa3/default-hypers.yaml create mode 100644 src/metatrain/experimental/dpa3/documentation.py delete mode 100644 src/metatrain/experimental/dpa3/schema-hypers.json diff --git a/pyproject.toml b/pyproject.toml index 21b0334c13..5195965ab9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,8 +65,7 @@ build-backend = "setuptools.build_meta" [project.optional-dependencies] dpa3 = [ - "deepmd-kit>=3.1.0", - "torch>=2.7", + "deepmd-kit[torch]>=3.1.0", ] soap-bpnn = [ "torch-spex>=0.1,<0.2", diff --git a/src/metatrain/experimental/dpa3/default-hypers.yaml b/src/metatrain/experimental/dpa3/default-hypers.yaml deleted file mode 100644 index 8100eb8fde..0000000000 --- a/src/metatrain/experimental/dpa3/default-hypers.yaml +++ /dev/null @@ -1,68 +0,0 @@ -architecture: - name: experimental.dpa3 - model: - type_map: [H, C, N, O] - descriptor: - type: dpa3 - repflow: - n_dim: 128 - e_dim: 64 - a_dim: 32 - nlayers: 6 - e_rcut: 6.0 - e_rcut_smth: 5.3 - e_sel: 1200 - a_rcut: 4.0 - a_rcut_smth: 3.5 - a_sel: 300 - axis_neuron: 4 - skip_stat: true - a_compress_rate: 1 - a_compress_e_rate: 2 - a_compress_use_split: true - update_angle: true - update_style: res_residual - update_residual: 0.1 - update_residual_init: const - smooth_edge_update: true - use_dynamic_sel: true - sel_reduce_factor: 10.0 - activation_function: custom_silu:10.0 - use_tebd_bias: false - precision: float32 - concat_output_tebd: false - fitting_net: - neuron: [240, 240, 240] - resnet_dt: true - seed: 1 - precision: float32 - activation_function: custom_silu:10.0 - type: ener - numb_fparam: 0 - numb_aparam: 0 - dim_case_embd: 0 - trainable: true - rcond: null - atom_ener: [] - use_aparam_as_mask: false - training: - distributed: false - distributed_port: 39591 - batch_size: 8 - num_epochs: 100 - learning_rate: 0.001 - early_stopping_patience: 200 - scheduler_patience: 100 - scheduler_factor: 0.8 - log_interval: 1 - checkpoint_interval: 25 - scale_targets: true - fixed_composition_weights: {} - per_structure_targets: [] - log_mae: false - log_separate_blocks: false - best_model_metric: rmse_prod - loss: - type: mse - weights: {} - reduction: mean diff --git a/src/metatrain/experimental/dpa3/documentation.py b/src/metatrain/experimental/dpa3/documentation.py new file mode 100644 index 0000000000..8e2cfc8a82 --- /dev/null +++ b/src/metatrain/experimental/dpa3/documentation.py @@ -0,0 +1,146 @@ +""" +DPA3 +==== +""" + +from typing import Literal, Optional + +from typing_extensions import TypedDict + +from metatrain.utils.additive import FixedCompositionWeights +from metatrain.utils.hypers import init_with_defaults +from metatrain.utils.loss import LossSpecification + + +########################### +# MODEL HYPERPARAMETERS # +########################### + + +class RepflowHypers(TypedDict): + n_dim: int = 128 + e_dim: int = 64 + a_dim: int = 32 + nlayers: int = 6 + e_rcut: float = 6.0 + 
e_rcut_smth: float = 5.3 + e_sel: int = 1200 + a_rcut: float = 4.0 + a_rcut_smth: float = 3.5 + a_sel: int = 300 + axis_neuron: int = 4 + skip_stat: bool = True + a_compress_rate: int = 1 + a_compress_e_rate: int = 2 + a_compress_use_split: bool = True + update_angle: bool = True + # TODO: what are the options here + update_style: str = "res_residual" + update_residual: float = 0.1 + # TODO: what are the options here + update_residual_init: str = "const" + smooth_edge_update: bool = True + use_dynamic_sel: bool = True + sel_reduce_factor: float = 10.0 + + +class DescriptorHypers(TypedDict): + # TODO: what are the options here + type: str = "dpa3" + repflow: RepflowHypers = init_with_defaults(RepflowHypers) + # TODO: what are the options here + activation_function: str = "custom_silu:10.0" + use_tebd_bias: bool = False + # TODO: what are the options here + precision: str = "float32" + concat_output_tebd: bool = False + + +class FittingNetHypers(TypedDict): + neuron: list[int] = [240, 240, 240] + resnet_dt: bool = True + seed: int = 1 + # TODO: what are the options here + precision: str = "float32" + # TODO: what are the options here + activation_function: str = "custom_silu:10.0" + # TODO: what are the options here + type: str = "ener" + numb_fparam: int = 0 + numb_aparam: int = 0 + dim_case_embd: int = 0 + trainable: bool = True + rcond: Optional[float] = None + atom_ener: list[float] = [] + use_aparam_as_mask: bool = False + + +class ModelHypers(TypedDict): + """Hyperparameters for the DPA3 model.""" + + type_map: list[str] = ["H", "C", "N", "O"] + + descriptor: DescriptorHypers = init_with_defaults(DescriptorHypers) + fitting_net: FittingNetHypers = init_with_defaults(FittingNetHypers) + + +############################## +# TRAINER HYPERPARAMETERS # +############################## + + +class TrainerHypers(TypedDict): + distributed: bool = False + """Whether to use distributed training""" + distributed_port: int = 39591 + """Port for DDP communication""" + batch_size: int = 8 + """The number of samples to use in each batch of training. This + hyperparameter controls the tradeoff between training speed and memory usage. In + general, larger batch sizes will lead to faster training, but might require more + memory.""" + num_epochs: int = 100 + """Number of epochs.""" + learning_rate: float = 0.001 + """Learning rate.""" + + # TODO: update the scheduler or not + scheduler_patience: int = 100 + scheduler_factor: float = 0.8 + + log_interval: int = 1 + """Interval to log metrics.""" + checkpoint_interval: int = 100 + """Interval to save checkpoints.""" + scale_targets: bool = True + """Normalize targets to unit std during training.""" + fixed_composition_weights: FixedCompositionWeights = {} + """Weights for atomic contributions. + + This is passed to the ``fixed_weights`` argument of + :meth:`CompositionModel.train_model + `, + see its documentation to understand exactly what to pass here. + """ + # fixed_scaling_weights: FixedScalerWeights = {} + # """Weights for target scaling. + + # This is passed to the ``fixed_weights`` argument of + # :meth:`Scaler.train_model `, + # see its documentation to understand exactly what to pass here. + # """ + per_structure_targets: list[str] = [] + """Targets to calculate per-structure losses.""" + # num_workers: Optional[int] = None + # """Number of workers for data loading. 
If not provided, it is set + # automatically.""" + log_mae: bool = False + """Log MAE alongside RMSE""" + log_separate_blocks: bool = False + """Log per-block error.""" + best_model_metric: Literal["rmse_prod", "mae_prod", "loss"] = "rmse_prod" + """Metric used to select best checkpoint (e.g., ``rmse_prod``)""" + + loss: str | dict[str, LossSpecification] = "mse" + """This section describes the loss function to be used. See the + :ref:`loss-functions` for more details.""" diff --git a/src/metatrain/experimental/dpa3/model.py b/src/metatrain/experimental/dpa3/model.py index 54d7fcbfc0..6b582c6fe0 100644 --- a/src/metatrain/experimental/dpa3/model.py +++ b/src/metatrain/experimental/dpa3/model.py @@ -22,6 +22,8 @@ from metatrain.utils.scaler import Scaler from metatrain.utils.sum_over_atoms import sum_over_atoms +from .documentation import ModelHypers + # Data processing def concatenate_structures(systems: List[System]): @@ -67,7 +69,7 @@ def concatenate_structures(systems: List[System]): # Model definition -class DPA3(ModelInterface): +class DPA3(ModelInterface[ModelHypers]): __checkpoint_version__ = 1 __supported_devices__ = ["cuda", "cpu"] __supported_dtypes__ = [torch.float32, torch.float64] @@ -84,7 +86,7 @@ class DPA3(ModelInterface): component_labels: Dict[str, List[List[Labels]]] # torchscript needs this - def __init__(self, hypers: Dict, dataset_info: DatasetInfo) -> None: + def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None: super().__init__(hypers, dataset_info, self.__default_metadata__) self.atomic_types = dataset_info.atomic_types self.dtype = self.hypers["descriptor"]["precision"] @@ -274,7 +276,7 @@ def forward( if not self.training: # at evaluation, we also introduce the scaler and additive contributions - return_dict = self.scaler(return_dict) + return_dict = self.scaler(systems, return_dict) for additive_model in self.additive_models: outputs_for_additive_model: Dict[str, ModelOutput] = {} for name, output in outputs.items(): diff --git a/src/metatrain/experimental/dpa3/schema-hypers.json b/src/metatrain/experimental/dpa3/schema-hypers.json deleted file mode 100644 index d9a9f170cb..0000000000 --- a/src/metatrain/experimental/dpa3/schema-hypers.json +++ /dev/null @@ -1,295 +0,0 @@ -{ - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": { - "name": { - "type": "string", - "enum": [ - "experimental.dpa3" - ] - }, - "model": { - "type": "object", - "properties": { - "type_map": { - "type": "array", - "items": { - "type": "string" - } - }, - "descriptor": { - "type": "object", - "properties": { - "repflow": { - "type": "object", - "properties": { - "n_dim": { - "type": "integer" - }, - "e_dim": { - "type": "integer" - }, - "a_dim": { - "type": "integer" - }, - "nlayers": { - "type": "integer" - }, - "e_rcut": { - "type": "number" - }, - "e_rcut_smth": { - "type": "number" - }, - "e_sel": { - "type": "integer" - }, - "a_rcut": { - "type": "number" - }, - "a_rcut_smth": { - "type": "number" - }, - "a_sel": { - "type": "integer" - }, - "axis_neuron": { - "type": "integer" - }, - "skip_stat": { - "type": "boolean" - }, - "a_compress_rate": { - "type": "number" - }, - "a_compress_e_rate": { - "type": "number" - }, - "a_compress_use_split": { - "type": "boolean" - }, - "update_angle": { - "type": "boolean" - }, - "update_style": { - "type": "string" - }, - "update_residual": { - "type": "number" - }, - "update_residual_init": { - "type": "string" - }, - "smooth_edge_update": { - "type": "boolean" - }, - 
"use_dynamic_sel": { - "type": "boolean" - }, - "sel_reduce_factor": { - "type": "number" - } - }, - "required": [ - "n_dim", - "e_dim", - "a_dim", - "nlayers" - ] - }, - "activation_function": { - "type": "string" - }, - "use_tebd_bias": { - "type": "boolean" - }, - "precision": { - "type": "string" - }, - "concat_output_tebd": { - "type": "boolean" - } - } - }, - "fitting_net": { - "type": "object", - "properties": { - "neuron": { - "type": "array", - "items": { - "type": "integer" - } - }, - "resnet_dt": { - "type": "boolean" - }, - "seed": { - "type": "integer" - }, - "precision": { - "type": "string" - }, - "activation_function": { - "type": "string" - }, - "type": { - "type": "string" - }, - "numb_fparam": { - "type": "integer" - }, - "numb_aparam": { - "type": "integer" - }, - "dim_case_embd": { - "type": "integer" - }, - "trainable": { - "type": "boolean" - }, - "rcond": { - "type": [ - "number", - "null" - ] - }, - "atom_ener": { - "type": "array", - "items": { - "type": "number" - } - }, - "use_aparam_as_mask": { - "type": "boolean" - } - } - }, - "cutoff_width": { - "type": "object", - "properties": { - "d_pet": { - "type": "integer" - }, - "d_head": { - "type": "integer" - }, - "d_feedforward": { - "type": "integer" - }, - "num_heads": { - "type": "integer" - }, - "num_attention_layers": { - "type": "integer" - }, - "num_gnn_layers": { - "type": "integer" - }, - "zbl": { - "type": "boolean" - }, - "long_range": { - "type": "object", - "properties": { - "enable": { - "type": "boolean" - }, - "use_ewald": { - "type": "boolean" - }, - "smearing": { - "type": "number" - }, - "kspace_resolution": { - "type": "number" - }, - "interpolation_nodes": { - "type": "integer" - } - } - } - } - } - }, - "additionalProperties": false - }, - "training": { - "type": "object", - "properties": { - "distributed": { - "type": "boolean" - }, - "distributed_port": { - "type": "integer" - }, - "batch_size": { - "type": "integer" - }, - "num_epochs": { - "type": "integer" - }, - "learning_rate": { - "type": "number" - }, - "early_stopping_patience": { - "type": "integer" - }, - "scheduler_patience": { - "type": "integer" - }, - "scheduler_factor": { - "type": "number" - }, - "log_interval": { - "type": "integer" - }, - "checkpoint_interval": { - "type": "integer" - }, - "scale_targets": { - "type": "boolean" - }, - "fixed_composition_weights": { - "type": "object", - "patternProperties": { - "^.*$": { - "type": "object", - "propertyNames": { - "pattern": "^[0-9]+$" - }, - "additionalProperties": { - "type": "number" - } - } - }, - "additionalProperties": false - }, - "per_structure_targets": { - "type": "array", - "items": { - "type": "string" - } - }, - "log_mae": { - "type": "boolean" - }, - "log_separate_blocks": { - "type": "boolean" - }, - "best_model_metric": { - "type": "string", - "enum": [ - "rmse_prod", - "mae_prod", - "loss" - ] - }, - "loss": { - "type": "object" - }, - "additionalProperties": false - } - } - }, - "additionalProperties": false -} diff --git a/src/metatrain/experimental/dpa3/trainer.py b/src/metatrain/experimental/dpa3/trainer.py index a31d69440f..a42ad7e87c 100644 --- a/src/metatrain/experimental/dpa3/trainer.py +++ b/src/metatrain/experimental/dpa3/trainer.py @@ -14,6 +14,7 @@ CombinedDataLoader, Dataset, _is_disk_dataset, + unpack_batch, ) from metatrain.utils.distributed.distributed_data_parallel import ( DistributedDataParallel, @@ -34,13 +35,14 @@ batch_to, ) +from .documentation import TrainerHypers from .model import DPA3 -class Trainer(TrainerInterface): +class 
Trainer(TrainerInterface[TrainerHypers]): __checkpoint_version__ = 1 - def __init__(self, hypers): + def __init__(self, hypers: TrainerHypers): super().__init__(hypers) self.optimizer_state_dict = None @@ -137,7 +139,11 @@ def train( if self.hypers["scale_targets"]: logging.info("Calculating scaling weights") model.scaler.train_model( - train_datasets, model.additive_models, treat_as_additive=True + train_datasets, + model.additive_models, + self.hypers["batch_size"], + is_distributed, + # TODO: fixed_scaling_weights ) if is_distributed: @@ -178,7 +184,9 @@ def train( # Create dataloader for the training datasets: train_dataloaders = [] - for train_dataset, train_sampler in zip(train_datasets, train_samplers): + for train_dataset, train_sampler in zip( + train_datasets, train_samplers, strict=True + ): if len(train_dataset) < self.hypers["batch_size"]: raise ValueError( f"A training dataset has fewer samples " @@ -206,7 +214,7 @@ def train( # Create dataloader for the validation datasets: val_dataloaders = [] - for val_dataset, val_sampler in zip(val_datasets, val_samplers): + for val_dataset, val_sampler in zip(val_datasets, val_samplers, strict=True): if len(val_dataset) < self.hypers["batch_size"]: raise ValueError( f"A validation dataset has fewer samples " @@ -306,7 +314,7 @@ def train( for batch in train_dataloader: optimizer.zero_grad() - systems, targets, extra_data = batch + systems, targets, extra_data = unpack_batch(batch) systems, targets, extra_data = batch_to( systems, targets, extra_data, device=device ) @@ -317,7 +325,7 @@ def train( systems, targets, additive_model, train_targets ) targets = remove_scale( - targets, (model.module if is_distributed else model).scaler + systems, targets, (model.module if is_distributed else model).scaler ) systems, targets, extra_data = batch_to( systems, targets, extra_data, dtype=dtype @@ -365,7 +373,7 @@ def train( val_loss = 0.0 for batch in val_dataloader: - systems, targets, extra_data = batch + systems, targets, extra_data = unpack_batch(batch) systems, targets, extra_data = batch_to( systems, targets, extra_data, device=device ) @@ -376,7 +384,7 @@ def train( systems, targets, additive_model, train_targets ) targets = remove_scale( - targets, (model.module if is_distributed else model).scaler + systems, targets, (model.module if is_distributed else model).scaler ) systems, targets, extra_data = batch_to( systems, targets, extra_data, dtype=dtype @@ -424,9 +432,6 @@ def train( finalized_val_info = {"loss": val_loss, **finalized_val_info} if epoch == start_epoch: - scaler_scales = ( - model.module if is_distributed else model - ).scaler.get_scales_dict() metric_logger = MetricLogger( log_obj=ROOT_LOGGER, dataset_info=( @@ -434,14 +439,6 @@ def train( ).dataset_info, initial_metrics=[finalized_train_info, finalized_val_info], names=["training", "validation"], - scales={ - key: ( - scaler_scales[key.split(" ")[0]] - if ("MAE" in key or "RMSE" in key) - else 1.0 - ) - for key in finalized_train_info.keys() - }, ) if epoch % self.hypers["log_interval"] == 0: metric_logger.log( From 895e8580398842b1a1ec4228de3b3029caf5b9c8 Mon Sep 17 00:00:00 2001 From: Pol Febrer Calabozo Date: Thu, 29 Jan 2026 15:05:50 +0100 Subject: [PATCH 15/18] linting --- src/metatrain/experimental/dpa3/trainer.py | 14 +++++++------- tests/utils/test_architectures.py | 1 + 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/metatrain/experimental/dpa3/trainer.py b/src/metatrain/experimental/dpa3/trainer.py index a42ad7e87c..848b78b807 100644 --- 
a/src/metatrain/experimental/dpa3/trainer.py +++ b/src/metatrain/experimental/dpa3/trainer.py @@ -1,7 +1,7 @@ import copy import logging from pathlib import Path -from typing import Any, Dict, List, Literal, Union +from typing import Any, Dict, List, Literal, Union, cast import torch import torch.distributed @@ -23,7 +23,7 @@ from metatrain.utils.evaluate_model import evaluate_model from metatrain.utils.io import check_file_extension from metatrain.utils.logging import ROOT_LOGGER, MetricLogger -from metatrain.utils.loss import LossAggregator +from metatrain.utils.loss import LossAggregator, LossSpecification from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, @@ -47,9 +47,9 @@ def __init__(self, hypers: TrainerHypers): self.optimizer_state_dict = None self.scheduler_state_dict = None - self.epoch = None - self.best_epoch = None - self.best_metric = None + self.epoch: int | None = None + self.best_epoch: int | None = None + self.best_metric: float | None = None self.best_model_state_dict = None self.best_optimizer_state_dict = None @@ -243,7 +243,7 @@ def train( outputs_list.append(f"{target_name}_{gradient_name}_gradients") # Create a loss function: - loss_hypers = self.hypers["loss"] + loss_hypers = cast(Dict[str, LossSpecification], self.hypers["loss"]) # mypy loss_fn = LossAggregator( targets=train_targets, config=loss_hypers, @@ -538,7 +538,7 @@ def save_checkpoint(self, model, path: Union[str, Path]): def load_checkpoint( cls, checkpoint: Dict[str, Any], - hypers: Dict[str, Any], + hypers: TrainerHypers, context: Literal["restart", "finetune"], # not used at the moment ) -> "Trainer": trainer = cls(hypers) diff --git a/tests/utils/test_architectures.py b/tests/utils/test_architectures.py index 0d98db0f18..bb233a8ca5 100644 --- a/tests/utils/test_architectures.py +++ b/tests/utils/test_architectures.py @@ -35,6 +35,7 @@ def test_find_all_architectures(): assert "deprecated.nanopet" in all_arches assert "llpr" in all_arches + def test_get_architecture_path(): assert get_architecture_path("soap_bpnn") == PACKAGE_ROOT / "soap_bpnn" From e2b1e34137e392cde500217297ca2c8641a00084 Mon Sep 17 00:00:00 2001 From: Pol Febrer Calabozo Date: Fri, 30 Jan 2026 11:55:17 +0100 Subject: [PATCH 16/18] Moved to shared tests (some still failing) --- .../experimental/dpa3/checkpoints.py | 0 src/metatrain/experimental/dpa3/model.py | 22 +- .../checkpoints/model-v1_trainer-v1.ckpt.gz | Bin 0 -> 18518 bytes .../experimental/dpa3/tests/test_basic.py | 67 +++++ .../experimental/dpa3/tests/test_continue.py | 103 ------- .../dpa3/tests/test_functionality.py | 270 ------------------ .../dpa3/tests/test_torchscript.py | 91 ------ src/metatrain/utils/testing/architectures.py | 16 +- src/metatrain/utils/testing/output.py | 24 +- src/metatrain/utils/testing/torchscript.py | 10 + 10 files changed, 130 insertions(+), 473 deletions(-) create mode 100644 src/metatrain/experimental/dpa3/checkpoints.py create mode 100644 src/metatrain/experimental/dpa3/tests/checkpoints/model-v1_trainer-v1.ckpt.gz create mode 100644 src/metatrain/experimental/dpa3/tests/test_basic.py delete mode 100644 src/metatrain/experimental/dpa3/tests/test_continue.py delete mode 100644 src/metatrain/experimental/dpa3/tests/test_functionality.py delete mode 100644 src/metatrain/experimental/dpa3/tests/test_torchscript.py diff --git a/src/metatrain/experimental/dpa3/checkpoints.py b/src/metatrain/experimental/dpa3/checkpoints.py new file mode 
100644 index 0000000000..e69de29bb2 diff --git a/src/metatrain/experimental/dpa3/model.py b/src/metatrain/experimental/dpa3/model.py index 6b582c6fe0..d593a6ca88 100644 --- a/src/metatrain/experimental/dpa3/model.py +++ b/src/metatrain/experimental/dpa3/model.py @@ -1,3 +1,4 @@ +import logging from typing import Any, Dict, List, Literal, Optional import metatensor.torch as mts @@ -22,6 +23,7 @@ from metatrain.utils.scaler import Scaler from metatrain.utils.sum_over_atoms import sum_over_atoms +from . import checkpoints from .documentation import ModelHypers @@ -345,8 +347,10 @@ def load_checkpoint( model_data = checkpoint["model_data"] if context == "restart": + logging.info(f"Using latest model from epoch {checkpoint['epoch']}") model_state_dict = checkpoint["model_state_dict"] - elif context == "finetune" or context == "export": + elif context in {"finetune", "export"}: + logging.info(f"Using best model from epoch {checkpoint['best_epoch']}") model_state_dict = checkpoint["best_model_state_dict"] if model_state_dict is None: model_state_dict = checkpoint["model_state_dict"] @@ -407,7 +411,19 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: @classmethod def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: - # version is still one, there are no new versions + for v in range(1, cls.__checkpoint_version__): + if checkpoint["model_ckpt_version"] == v: + update = getattr(checkpoints, f"model_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["model_ckpt_version"] = v + 1 + + if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using model " + f"version {checkpoint['model_ckpt_version']}, while the current model " + f"version is {cls.__checkpoint_version__}." 
+            )
+
         return checkpoint

     def get_checkpoint(self) -> Dict:
@@ -419,6 +435,8 @@ def get_checkpoint(self) -> Dict:
                 "model_hypers": self.hypers,
                 "dataset_info": self.dataset_info,
             },
+            "epoch": None,
+            "best_epoch": None,
             "model_state_dict": self.state_dict(),
             "best_model_state_dict": None,
         }
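The empty checkpoints.py added above is the registry that upgrade_checkpoint() now walks: for each version step it resolves model_update_v{v}_v{v + 1} by name and applies it in place before bumping model_ckpt_version. A minimal sketch of what such a hook could look like once a v2 format exists (the function name follows the convention above; the hook itself and the "finetune_config" key are hypothetical, nothing in this series defines them):

    # Hypothetical contents of checkpoints.py for a future v1 -> v2 upgrade.
    from typing import Any, Dict


    def model_update_v1_v2(checkpoint: Dict[str, Any]) -> None:
        # Hooks mutate the checkpoint in place; upgrade_checkpoint() bumps
        # checkpoint["model_ckpt_version"] itself after each step, so the
        # hook only has to fill in whatever the new format requires.
        checkpoint.setdefault("finetune_config", {})  # illustrative key only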
[...] dict:
+        """Minimal hyperparameters for a DPA3 model for the smallest
+        checkpoint possible.
+
+        :return: Hyperparameters for the model.
+        """
+        hypers = copy.deepcopy(get_default_hypers(self.architecture)["model"])
+        hypers["descriptor"]["repflow"]["n_dim"] = 2
+        hypers["descriptor"]["repflow"]["e_dim"] = 2
+        hypers["descriptor"]["repflow"]["a_dim"] = 2
+        hypers["descriptor"]["repflow"]["e_sel"] = 1
+        hypers["descriptor"]["repflow"]["a_sel"] = 1
+        hypers["descriptor"]["repflow"]["axis_neuron"] = 1
+        hypers["descriptor"]["repflow"]["nlayers"] = 1
+        hypers["fitting_net"]["neuron"] = [1, 1]
+        return hypers
+
+
+class TestInput(InputTests, DPA3Tests): ...
+
+
+class TestOutput(OutputTests, DPA3Tests):
+    supports_multiscalar_outputs = False
+    supports_spherical_outputs = False
+    supports_vector_outputs = False
+    supports_features = False
+    supports_last_layer_features = False
+
+
+class TestAutograd(AutogradTests, DPA3Tests):
+    cuda_nondet_tolerance = 1e-12
+
+
+class TestTorchscript(TorchscriptTests, DPA3Tests):
+    float_hypers = ["descriptor.repflow.e_rcut", "descriptor.repflow.e_rcut_smth"]
+    supports_spherical_outputs = False
+
+
+class TestExported(ExportedTests, DPA3Tests): ...
+
+
+class TestTraining(TrainingTests, DPA3Tests): ...
+
+
+class TestCheckpoints(CheckpointTests, DPA3Tests):
+    incompatible_trainer_checkpoints = []
diff --git a/src/metatrain/experimental/dpa3/tests/test_continue.py b/src/metatrain/experimental/dpa3/tests/test_continue.py
deleted file mode 100644
index 3c0616997d..0000000000
--- a/src/metatrain/experimental/dpa3/tests/test_continue.py
+++ /dev/null
@@ -1,103 +0,0 @@
-import shutil
-
-import metatensor
-import torch
-from omegaconf import OmegaConf
-
-from metatrain.experimental.dpa3 import DPA3, Trainer
-from metatrain.utils.data import Dataset, DatasetInfo
-from metatrain.utils.data.readers import read_systems, read_targets
-from metatrain.utils.data.target_info import get_energy_target_info
-from metatrain.utils.io import model_from_checkpoint
-from metatrain.utils.neighbor_lists import (
-    get_requested_neighbor_lists,
-    get_system_with_neighbor_lists,
-)
-
-from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS
-
-
-def test_continue(monkeypatch, tmp_path):
-    """Tests that a model can be checkpointed and loaded
-    for a continuation of the training process"""
-
-    monkeypatch.chdir(tmp_path)
-    shutil.copy(DATASET_PATH, "qm9_reduced_100.xyz")
-
-    systems = read_systems(DATASET_PATH)
-    systems = [system.to(torch.float32) for system in systems]
-
-    target_info_dict = {}
-    target_info_dict["mtt::U0"] = get_energy_target_info({"unit": "eV"})
-
-    dataset_info = DatasetInfo(
-        length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict
-    )
-    model = DPA3(MODEL_HYPERS, dataset_info).to(systems[0].positions.device)
-    requested_neighbor_lists = get_requested_neighbor_lists(model)
-    systems = [
-        get_system_with_neighbor_lists(system, requested_neighbor_lists)
-        for system in systems
-    ]
-    output_before = model(systems[:5], {"mtt::U0": model.outputs["mtt::U0"]})
-
-    conf = {
-        "mtt::U0": {
-            "quantity": "energy",
-            "read_from": DATASET_PATH,
-            "reader": "ase",
-            "key": "U0",
-            "unit": "eV",
-            "type": "scalar",
-            "per_atom": False,
-            "num_subtargets": 1,
-            "forces": False,
-            "stress": False,
-            "virial": False,
-        }
-    }
-    targets, _ = read_targets(OmegaConf.create(conf))
-
-    # systems in float64 are required for training
-    systems = [system.to(torch.float64) for system in systems]
-    dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})
-
-    hypers = DEFAULT_HYPERS.copy()
-    hypers["training"]["num_epochs"] = 0
-    trainer = Trainer(hypers["training"])
-    trainer.train(
-        model=model,
-        dtype=torch.float32,
-        devices=[torch.device("cpu")],
-        train_datasets=[dataset],
-        val_datasets=[dataset],
-        checkpoint_dir=".",
-    )
-    trainer.save_checkpoint(model, "temp.ckpt")
-    checkpoint = torch.load("temp.ckpt", weights_only=False, map_location="cpu")
-    model_after = model_from_checkpoint(checkpoint, context="restart")
-    assert isinstance(model_after, DPA3)
-    model_after.restart(dataset_info)
-
-    hypers["training"]["num_epochs"] = 0
-    trainer = Trainer(hypers["training"])
-    trainer.train(
-        model=model_after,
-        dtype=torch.float32,
-        devices=[torch.device("cpu")],
-        train_datasets=[dataset],
-        val_datasets=[dataset],
-        checkpoint_dir=".",
-    )
-
-    # evaluation
-    systems = [system.to(torch.float32) for system in systems]
-
-    model.eval()
-    model_after.eval()
-
-    # Predict on the first five systems
-    output_before = model(systems[:5], {"mtt::U0": model.outputs["mtt::U0"]})
-    output_after = model_after(systems[:5], {"mtt::U0": model_after.outputs["mtt::U0"]})
-
-    assert metatensor.torch.allclose(output_before["mtt::U0"], output_after["mtt::U0"])
diff --git a/src/metatrain/experimental/dpa3/tests/test_functionality.py b/src/metatrain/experimental/dpa3/tests/test_functionality.py
deleted file mode 100644
index 8ec99deb87..0000000000
--- a/src/metatrain/experimental/dpa3/tests/test_functionality.py
+++ /dev/null
@@ -1,270 +0,0 @@
-import metatensor.torch as mts
-import torch
-from metatomic.torch import ModelOutput, System
-from omegaconf import OmegaConf
-
-from metatrain.experimental.dpa3 import DPA3
-from metatrain.utils.architectures import check_architecture_options
-from metatrain.utils.data import DatasetInfo
-from metatrain.utils.data.target_info import (
-    get_energy_target_info,
-)
-from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
-
-from . import DEFAULT_HYPERS, MODEL_HYPERS
-
-
-def test_prediction():
-    """Tests the basic functionality of the forward pass of the model."""
-    dataset_info = DatasetInfo(
-        length_unit="Angstrom",
-        atomic_types=[1, 6, 7, 8],
-        targets={
-            "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"})
-        },
-    )
-
-    model = DPA3(MODEL_HYPERS, dataset_info).to("cpu")
-
-    system = System(
-        types=torch.tensor([6, 6]),
-        positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
-        cell=torch.zeros(3, 3),
-        pbc=torch.tensor([False, False, False]),
-    )
-    system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists())
-    outputs = {"energy": ModelOutput(per_atom=False)}
-    model([system, system], outputs)
-
-
-def test_dpa3_padding():
-    """Tests that the model predicts the same energy independently of the
-    padding size."""
-
-    dataset_info = DatasetInfo(
-        length_unit="Angstrom",
-        atomic_types=[1, 6, 7, 8],
-        targets={
-            "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"})
-        },
-    )
-
-    model = DPA3(MODEL_HYPERS, dataset_info).to("cpu")
-
-    system = System(
-        types=torch.tensor([6, 6]),
-        positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
-        cell=torch.zeros(3, 3),
-        pbc=torch.tensor([False, False, False]),
-    )
-    system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists())
-    outputs = {"energy": ModelOutput(per_atom=False)}
-    lone_output = model([system], outputs)
-
-    system_2 = System(
-        types=torch.tensor([6, 6, 6, 6, 6, 6, 6]),
-        positions=torch.tensor(
-            [
-                [0.0, 0.0, 0.0],
-                [0.0, 0.0, 1.0],
-                [0.0, 0.0, 2.0],
-                [0.0, 0.0, 3.0],
-                [0.0, 0.0, 4.0],
-                [0.0, 0.0, 5.0],
-                [0.0, 0.0, 6.0],
-            ]
-        ),
-        cell=torch.zeros(3, 3),
-        pbc=torch.tensor([False, False, False]),
-    )
-    system_2 = get_system_with_neighbor_lists(
-        system_2, model.requested_neighbor_lists()
-    )
-    padded_output = model([system, system_2], outputs)
-
-    lone_energy = lone_output["energy"].block().values.squeeze(-1)[0]
-    padded_energy = padded_output["energy"].block().values.squeeze(-1)[0]
-
-    assert torch.allclose(lone_energy, padded_energy, atol=1e-6, rtol=1e-6)
-
-
-def test_prediction_subset_elements():
-    """Tests that the model can predict on a subset of the elements it was trained
-    on."""
-
-    dataset_info = DatasetInfo(
-        length_unit="Angstrom",
-        atomic_types=[1, 6, 7, 8],
-        targets={
-            "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"})
-        },
-    )
-
-    model = DPA3(MODEL_HYPERS, dataset_info).to("cpu")
-
-    system = System(
-        types=torch.tensor([6, 6]),
-        positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
-        cell=torch.zeros(3, 3),
-        pbc=torch.tensor([False, False, False]),
-    )
-    system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists())
-    model(
-        [system],
-        {"energy": model.outputs["energy"]},
-    )
-
-
-def test_prediction_subset_atoms():
-    """Tests that the model can predict on a subset
-    of the atoms in a system."""
-
-    # we need float64 for this test, then we will change it back at the end
-    default_dtype_before = torch.get_default_dtype()
-    torch.set_default_dtype(torch.float64)
-
-    dataset_info = DatasetInfo(
-        length_unit="Angstrom",
-        atomic_types=[1, 6, 7, 8],
-        targets={
-            "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"})
-        },
-    )
-
-    model = DPA3(MODEL_HYPERS, dataset_info).to("cpu")
-
-    # Since we don't yet support atomic predictions, we will test this by
-    # predicting on a system with two monomers at a large distance
-
-    system_monomer = System(
-        types=torch.tensor([7, 8, 8]),
-        positions=torch.tensor(
-            [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0]],
-        ),
-        cell=torch.zeros(3, 3),
-        pbc=torch.tensor([False, False, False]),
-    )
-    system_monomer = get_system_with_neighbor_lists(
-        system_monomer, model.requested_neighbor_lists()
-    )
-
-    energy_monomer = model(
-        [system_monomer],
-        {"energy": ModelOutput(per_atom=False)},
-    )
-
-    system_far_away_dimer = System(
-        types=torch.tensor([7, 7, 8, 8, 8, 8]),
-        positions=torch.tensor(
-            [
-                [0.0, 0.0, 0.0],
-                [0.0, 50.0, 0.0],
-                [0.0, 0.0, 1.0],
-                [0.0, 0.0, 2.0],
-                [0.0, 51.0, 0.0],
-                [0.0, 42.0, 0.0],
-            ],
-        ),
-        cell=torch.zeros(3, 3),
-        pbc=torch.tensor([False, False, False]),
-    )
-    system_far_away_dimer = get_system_with_neighbor_lists(
-        system_far_away_dimer, model.requested_neighbor_lists()
-    )
-
-    selection_labels = mts.Labels(
-        names=["system", "atom"],
-        values=torch.tensor([[0, 0], [0, 2], [0, 3]]),
-    )
-
-    energy_dimer = model(
-        [system_far_away_dimer],
-        {"energy": ModelOutput(per_atom=False)},
-    )
-
-    energy_monomer_in_dimer = model(
-        [system_far_away_dimer],
-        {"energy": ModelOutput(per_atom=False)},
-        selected_atoms=selection_labels,
-    )
-
-    assert not mts.allclose(energy_monomer["energy"], energy_dimer["energy"])
-
-    assert mts.allclose(
-        energy_monomer["energy"], energy_monomer_in_dimer["energy"], atol=1e-6
-    )
-
-    torch.set_default_dtype(default_dtype_before)
-
-
-def test_output_per_atom():
-    """Tests that the model can output per-atom quantities."""
-    dataset_info = DatasetInfo(
-        length_unit="Angstrom",
-        atomic_types=[1, 6, 7, 8],
-        targets={
-            "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"})
-        },
-    )
-
-    model = DPA3(MODEL_HYPERS, dataset_info).to("cpu")
-
-    system = System(
-        types=torch.tensor([6, 1, 8, 7]),
-        positions=torch.tensor(
-            [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]],
-        ),
-        cell=torch.zeros(3, 3),
-        pbc=torch.tensor([False, False, False]),
-    )
-    system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists())
-
-    outputs = model(
-        [system],
-        {"energy": model.outputs["energy"]},
-    )
-
-    assert outputs["energy"].block().samples.names == ["system", "atom"]
-    assert outputs["energy"].block().values.shape == (4, 1)
-
-
-def test_fixed_composition_weights():
-    """Tests the correctness of the json schema for fixed_composition_weights"""
-
-    hypers = DEFAULT_HYPERS.copy()
-    hypers["training"]["fixed_composition_weights"] = {
-        "energy": {
-            1: 1.0,
-            6: 0.0,
-            7: 0.0,
-            8: 0.0,
-            9: 3000.0,
-        }
-    }
-    hypers = OmegaConf.create(hypers)
-    check_architecture_options(
-        name="experimental.dpa3", options=OmegaConf.to_container(hypers)
-    )
-
-
-def test_pet_single_atom():
-    """Tests that the model predicts correctly on a single atom."""
-
-    dataset_info = DatasetInfo(
-        length_unit="Angstrom",
-        atomic_types=[1, 6, 7, 8],
-        targets={
-            "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"})
-        },
-    )
-    model = DPA3(MODEL_HYPERS, dataset_info).to("cpu")
-
-    system = System(
-        types=torch.tensor([6]),
-        positions=torch.tensor([[0.0, 0.0, 1.0]]),
-        cell=torch.zeros(3, 3),
-        pbc=torch.tensor([False, False, False]),
-    )
-    system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists())
-    outputs = {"energy": ModelOutput(per_atom=False)}
-    model([system], outputs)
diff --git a/src/metatrain/experimental/dpa3/tests/test_torchscript.py b/src/metatrain/experimental/dpa3/tests/test_torchscript.py
deleted file mode 100644
index 4520193c1b..0000000000
--- a/src/metatrain/experimental/dpa3/tests/test_torchscript.py
+++ /dev/null
@@ -1,91 +0,0 @@
-import copy
-
-import torch
-from metatomic.torch import System
-
-from metatrain.experimental.dpa3 import DPA3
-from metatrain.utils.data import DatasetInfo
-from metatrain.utils.data.target_info import get_energy_target_info
-from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
-
-from . import MODEL_HYPERS
-
-
-def test_torchscript():
-    """Tests that the model can be jitted."""
-
-    dataset_info = DatasetInfo(
-        length_unit="Angstrom",
-        atomic_types=[1, 6, 7, 8],
-        targets={
-            "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"})
-        },
-    )
-
-    model = DPA3(MODEL_HYPERS, dataset_info).to("cpu")
-    system = System(
-        types=torch.tensor([6, 1, 8, 7]),
-        positions=torch.tensor(
-            [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]]
-        ),
-        cell=torch.zeros(3, 3),
-        pbc=torch.tensor([False, False, False]),
-    )
-    system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists())
-
-    model = torch.jit.script(model)
-    model(
-        [system],
-        {"energy": model.outputs["energy"]},
-    )
-
-
-def test_torchscript_save_load(tmpdir):
-    """Tests that the model can be jitted and saved."""
-
-    dataset_info = DatasetInfo(
-        length_unit="Angstrom",
-        atomic_types=[1, 6, 7, 8],
-        targets={
-            "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"})
-        },
-    )
-    model = DPA3(MODEL_HYPERS, dataset_info)
-
-    with tmpdir.as_cwd():
-        torch.jit.save(torch.jit.script(model), "model.pt")
-        torch.jit.load("model.pt")
-
-
-def test_torchscript_integers():
-    """Tests that the model can be jitted when some float
-    parameters are instead supplied as integers."""
-
-    new_hypers = copy.deepcopy(MODEL_HYPERS)
-    new_hypers["cutoff"] = 5
-    new_hypers["cutoff_width"] = 1
-
-    dataset_info = DatasetInfo(
-        length_unit="Angstrom",
-        atomic_types=[1, 6, 7, 8],
-        targets={
-            "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"})
-        },
-    )
-    model = DPA3(MODEL_HYPERS, dataset_info).to("cpu")
-
-    system = System(
-        types=torch.tensor([6, 1, 8, 7]),
-        positions=torch.tensor(
-            [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]]
-        ),
-        cell=torch.zeros(3, 3),
-        pbc=torch.tensor([False, False, False]),
-    )
-    system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists())
-
-    model = torch.jit.script(model)
-    model(
-        [system],
-        {"energy": model.outputs["energy"]},
-    )
diff --git a/src/metatrain/utils/testing/architectures.py b/src/metatrain/utils/testing/architectures.py
index 4d2dfd008b..8add5fbe9e 100644
--- a/src/metatrain/utils/testing/architectures.py
+++ b/src/metatrain/utils/testing/architectures.py
@@ -145,11 +145,21 @@ def per_atom(self, request: pytest.FixtureRequest) -> bool:
         """
         return request.param
 
-    @pytest.fixture
-    def dataset_info_scalar(self, per_atom: bool) -> DatasetInfo:
+    @pytest.fixture(params=[1, 5])
+    def num_subtargets(self, request: pytest.FixtureRequest) -> int:
+        """Fixture to provide different numbers of subtargets for
+        testing.
+
+        :param request: The pytest request fixture.
+        :return: The number of subtargets.
+        """
+        return request.param
+
+    def dataset_info_scalar(self, num_subtargets: int, per_atom: bool) -> DatasetInfo:
         """Fixture that provides a basic ``DatasetInfo`` with a scalar
         target for testing.
 
+        :param num_subtargets: The number of scalars in the target.
         :param per_atom: Whether the target is per-atom or not.
         :return: A ``DatasetInfo`` instance with a scalar target.
         """
@@ -163,7 +173,7 @@ def dataset_info_scalar(self, per_atom: bool) -> DatasetInfo:
                 "quantity": "scalar",
                 "unit": "",
                 "type": "scalar",
-                "num_subtargets": 5,
+                "num_subtargets": num_subtargets,
                 "per_atom": per_atom,
             },
         )
diff --git a/src/metatrain/utils/testing/output.py b/src/metatrain/utils/testing/output.py
index 929d2e3d15..a3d198c16b 100644
--- a/src/metatrain/utils/testing/output.py
+++ b/src/metatrain/utils/testing/output.py
@@ -36,6 +36,8 @@ class OutputTests(ArchitectureTests):
     supports_scalar_outputs: bool = True
     """Whether the model supports scalar outputs."""
+    supports_multiscalar_outputs: bool = True
+    """Whether the model supports outputs with multiple scalar subtargets."""
     supports_vector_outputs: bool = True
     """Whether the model supports vector outputs."""
     supports_spherical_outputs: bool = True
@@ -132,7 +134,11 @@ def _get_output(
         return model([system], {k: ModelOutput(per_atom=per_atom) for k in outputs})
 
     def test_output_scalar(
-        self, model_hypers: dict, dataset_info_scalar: DatasetInfo, per_atom: bool
+        self,
+        model_hypers: dict,
+        dataset_info_scalar: DatasetInfo,
+        num_subtargets: int,
+        per_atom: bool,
     ) -> None:
         """Tests that forward pass works for scalar outputs.
@@ -140,7 +146,11 @@ def test_output_scalar(
         and values shape.
 
         This test is skipped if the model does not support scalar outputs,
-        i.e., if ``supports_scalar_outputs`` is set to ``False``.
+        i.e., if ``supports_scalar_outputs`` is set to ``False``. The test
+        is run twice, once for single scalar outputs, and once for a scalar
+        output with multiple subtargets. If the model does not support
+        multiple scalar subtargets, set ``supports_multiscalar_outputs``
+        to ``False``, which will skip the test for multiple subtargets.
 
         If this test is failing, your model might:
@@ -150,23 +160,29 @@ def test_output_scalar(
         :param model_hypers: Hyperparameters to initialize the model.
         :param dataset_info_scalar: Dataset information with scalar outputs.
+        :param num_subtargets: The number of scalars that the target contains.
         :param per_atom: Whether the requested outputs are per-atom or not.
         """
         if not self.supports_scalar_outputs:
             pytest.skip(f"{self.architecture} does not support scalar outputs.")
 
+        if num_subtargets > 1 and not self.supports_multiscalar_outputs:
+            pytest.skip(
+                f"{self.architecture} does not support multiple scalar subtargets."
+            )
+
         outputs = self._get_output(
             model_hypers, dataset_info_scalar, per_atom, ["scalar"]
         )
         if per_atom:
             assert outputs["scalar"].block().samples.names == ["system", "atom"]
-            assert outputs["scalar"].block().values.shape == (4, 5)
+            assert outputs["scalar"].block().values.shape == (4, num_subtargets)
         else:
             assert outputs["scalar"].block().samples.names == ["system"], (
                 outputs["scalar"].block().samples.names
             )
-            assert outputs["scalar"].block().values.shape == (1, 5)
+            assert outputs["scalar"].block().values.shape == (1, num_subtargets)
 
     def test_output_vector(
         self, model_hypers: dict, dataset_info_vector: DatasetInfo, per_atom: bool
diff --git a/src/metatrain/utils/testing/torchscript.py b/src/metatrain/utils/testing/torchscript.py
index 6703ce336f..b4ebcc9746 100644
--- a/src/metatrain/utils/testing/torchscript.py
+++ b/src/metatrain/utils/testing/torchscript.py
@@ -1,6 +1,7 @@
 import copy
 from typing import Any
 
+import pytest
 import torch
 from metatomic.torch import System
 
@@ -20,6 +21,9 @@ class TorchscriptTests(ArchitectureTests):
     that are floats. A test will set these to integers to test that
     TorchScript compilation works in that case."""
 
+    supports_spherical_outputs: bool = True
+    """Whether the model supports spherical tensor outputs."""
+
     def jit_compile(self, model: ModelInterface) -> torch.jit.ScriptModule:
         """JIT compiles the given model.
@@ -109,11 +113,17 @@ def test_torchscript_spherical(
     ) -> None:
         """Tests that there is no problem with jitting with spherical targets.
 
+        This test is skipped if the model does not support spherical outputs,
+        i.e., if ``supports_spherical_outputs`` is set to ``False``.
+
         :param model_hypers: Hyperparameters to initialize the model.
         :param dataset_info_spherical: Dataset to initialize the model
             (containing spherical targets).
         """
+        if not self.supports_spherical_outputs:
+            pytest.skip("Model does not support spherical outputs.")
+
         self.test_torchscript(
             model_hypers=model_hypers,
             dataset_info=dataset_info_spherical,

From f57414f5077fd486934a76dc15b1b604210781c0 Mon Sep 17 00:00:00 2001
From: Pol Febrer Calabozo
Date: Fri, 30 Jan 2026 12:06:42 +0100
Subject: [PATCH 17/18] Make docs build

---
 docs/src/architectures/dpa3.rst                  |  44 ------
 docs/src/architectures/nanopet.rst               | 133 ------------------
 src/metatrain/experimental/dpa3/documentation.py |   7 +-
 3 files changed, 5 insertions(+), 179 deletions(-)
 delete mode 100644 docs/src/architectures/dpa3.rst
 delete mode 100644 docs/src/architectures/nanopet.rst

diff --git a/docs/src/architectures/dpa3.rst b/docs/src/architectures/dpa3.rst
deleted file mode 100644
index b0f999d6a1..0000000000
--- a/docs/src/architectures/dpa3.rst
+++ /dev/null
@@ -1,44 +0,0 @@
-.. _architecture-dpa3:
-
-DPA3 (experimental)
-======================
-
-.. warning::
-
-    This is an **experimental architecture**. You should not use it for anything
-    important.
-
-This is an interface to the DPA3 architecture described in
-https://arxiv.org/abs/2506.01686 and implemented in deepmd-kit
-(https://github.com/deepmodeling/deepmd-kit).
-
-Installation
-------------
-
-To install the package, you can run the following command in the root
-directory of the repository:
-
-.. code-block:: bash
-
-    pip install metatrain[dpa3]
-
-This will install the package with the DPA3 dependencies.
-
-
-Default Hyperparameters
------------------------
-
-The default hyperparameters for the DPA3 architecture are:
-
-.. literalinclude:: ../../../src/metatrain/experimental/dpa3/default-hypers.yaml
-   :language: yaml
-
-
-Tuning Hyperparameters
-----------------------
-
-@littlepeachs this is where you can tell users how to tune the parameters of the model
-to obtain different speed/accuracy tradeoffs
-
-References
-----------
-
-.. footbibliography::
diff --git a/docs/src/architectures/nanopet.rst b/docs/src/architectures/nanopet.rst
deleted file mode 100644
index 821c93a004..0000000000
--- a/docs/src/architectures/nanopet.rst
+++ /dev/null
@@ -1,133 +0,0 @@
-.. _architecture-nanopet:
-
-NanoPET (deprecated)
-======================
-
-.. warning::
-
-    This is a **deprecated model**. You should not use it for anything important, and
-    support for it will be removed in future versions of metatrain. Please use the
-    :ref:`PET model <architecture-pet>` instead.
-
-Installation
-------------
-
-To install the package, you can run the following command in the root
-directory of the repository:
-
-.. code-block:: bash
-
-    pip install metatrain[nanopet]
-
-This will install the package with the nanoPET dependencies.
-
-
-Default Hyperparameters
------------------------
-
-The default hyperparameters for the nanoPET model are:
-
-.. literalinclude:: ../../../src/metatrain/deprecated/nanopet/default-hypers.yaml
-   :language: yaml
-
-
-Tuning Hyperparameters
-----------------------
-
-The default hyperparameters above will work well in most cases, but they
-may not be optimal for your specific dataset. In general, the most important
-hyperparameters to tune are (in decreasing order of importance):
-
-- ``cutoff``: This should be set to a value after which most of the interactions between
-  atoms are expected to be negligible. A lower cutoff will lead to faster models.
-- ``learning_rate``: The learning rate for the neural network. This hyperparameter
-  controls how much the weights of the network are updated at each step of the
-  optimization. A larger learning rate will lead to faster training, but might cause
-  instability and/or divergence.
-- ``batch_size``: The number of samples to use in each batch of training. This
-  hyperparameter controls the tradeoff between training speed and memory usage. In
-  general, larger batch sizes will lead to faster training, but might require more
-  memory.
-- ``d_pet``: This hyperparameter controls the width of the neural network. In general,
-  increasing it might lead to better accuracy, especially on larger datasets, at the
-  cost of increased training and evaluation time.
-- ``num_gnn_layers``: The number of graph neural network layers. In general, decreasing
-  this hyperparameter to 1 will lead to much faster models, at the expense of accuracy.
-  Increasing it may or may not lead to better accuracy, depending on the dataset, at the
-  cost of increased training and evaluation time.
-- ``num_attention_layers``: The number of attention layers in each layer of the graph
-  neural network. Depending on the dataset, increasing this hyperparameter might lead to
-  better accuracy, at the cost of increased training and evaluation time.
-- ``loss``: This section describes the loss function to be used. See the
-  :doc:`dedicated documentation page <../advanced-concepts/loss-functions>` for more
-  details.
-- ``long_range``: In some systems and datasets, enabling long-range Coulomb interactions
-  might be beneficial for the accuracy of the model and/or its physical correctness.
-  See below for a breakdown of the long-range section of the model hyperparameters.
-
-
-All Hyperparameters
--------------------
-
-:param name: ``deprecated.nanopet``
-
-model
-#####
-
-The model-related hyperparameters are
-
-:param cutoff: Spherical cutoff to use for atomic environments
-:param cutoff_width: Width of the shifted cosine cutoff function
-:param d_pet: Width of the neural network
-:param num_heads: Number of attention heads
-:param num_attention_layers: Number of attention layers in each GNN layer
-:param num_gnn_layers: Number of GNN layers
-:param heads: The type of head ("linear" or "mlp") to use for each target (e.g.
-  ``heads: {"energy": "linear", "mtt::dipole": "mlp"}``). All omitted targets will use a
-  MLP (multi-layer perceptron) head. MLP heads consist of two hidden layers with
-  dimensionality ``d_pet``.
-:param zbl: Whether to use the ZBL short-range repulsion as the baseline for the model
-:param long_range: Parameters related to long-range interactions. ``enable``: whether
-  to use long-range interactions; ``use_ewald``: whether to use an Ewald calculator
-  (faster for smaller systems); ``smearing``: the width of the Gaussian function used
-  to approximate the charge distribution in Fourier space; ``kspace_resolution``: the
-  spatial resolution of the Fourier-space grid used for calculating long-range
-  interactions; ``interpolation_nodes``: the number of grid points used in spline
-  interpolation for the P3M method.
-
-training
-########
-
-The hyperparameters for training are
-
-:param distributed: Whether to use distributed training
-:param distributed_port: Port to use for distributed training
-:param batch_size: Batch size for training
-:param num_epochs: Number of epochs to train for
-:param learning_rate: Learning rate for the optimizer
-:param scheduler_patience: Patience for the learning rate scheduler
-:param scheduler_factor: Factor to reduce the learning rate by
-:param log_interval: Interval at which to log training metrics
-:param checkpoint_interval: Interval at which to save model checkpoints
-:param scale_targets: Whether to scale the targets to have unit standard deviation
-  across the training set during training.
-:param fixed_composition_weights: Weights for fixed atomic contributions to scalar
-  targets
-:param per_structure_targets: Targets to calculate per-structure losses for
-:param log_mae: Also logs MAEs in addition to RMSEs.
-:param log_separate_blocks: Whether to log the errors of each block of the targets
-  separately.
-:param loss: The loss function to use, with the subfields described in the previous
-  section
-:param best_model_metric: specifies the validation set metric to use to select the best
-  model, i.e. the model that will be saved as ``model.ckpt`` and ``model.pt`` both in
-  the current directory and in the checkpoint directory. The default is ``rmse_prod``,
-  i.e., the product of the RMSEs for each target. Other options are ``mae_prod`` and
-  ``loss``.
-:param num_workers: Number of workers for data loading. If not provided, it is set
-  automatically.
-
-References
-----------
-
-.. footbibliography::
diff --git a/src/metatrain/experimental/dpa3/documentation.py b/src/metatrain/experimental/dpa3/documentation.py
index 8e2cfc8a82..2d897b65c9 100644
--- a/src/metatrain/experimental/dpa3/documentation.py
+++ b/src/metatrain/experimental/dpa3/documentation.py
@@ -1,6 +1,9 @@
 """
-DPA3
-====
+DPA3 (experimental)
+======================
+
+This is an interface to the DPA3 architecture described in https://arxiv.org/abs/2506.01686
+and implemented in deepmd-kit (https://github.com/deepmodeling/deepmd-kit).
""" from typing import Literal, Optional From 15ae9a0329d10e885ad5fa45e01e91f409ff1a47 Mon Sep 17 00:00:00 2001 From: Pol Febrer Calabozo Date: Fri, 30 Jan 2026 12:11:17 +0100 Subject: [PATCH 18/18] Fix test typo --- src/metatrain/utils/testing/architectures.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/metatrain/utils/testing/architectures.py b/src/metatrain/utils/testing/architectures.py index 8add5fbe9e..d6edd27076 100644 --- a/src/metatrain/utils/testing/architectures.py +++ b/src/metatrain/utils/testing/architectures.py @@ -154,7 +154,8 @@ def num_subtargets(self, request: pytest.FixtureRequest) -> int: :return: The number of subtargets. """ return request.param - + + @pytest.fixture def dataset_info_scalar(self, num_subtargets: int, per_atom: bool) -> DatasetInfo: """Fixture that provides a basic ``DatasetInfo`` with a scalar target for testing.