diff --git a/docs/src/dev-docs/changelog.rst b/docs/src/dev-docs/changelog.rst
index ea06574502..e130f6a4d3 100644
--- a/docs/src/dev-docs/changelog.rst
+++ b/docs/src/dev-docs/changelog.rst
@@ -30,6 +30,8 @@ Changed
 - SOAP-BPNN and MCoV now use species embeddings by default, allowing for better
   scalability and speed. The traditional SOAP-BPNN (and associated MCoV)
   architecture can be accessed by setting ``legacy: True``
+- A minimum learning rate has been added to the PET LR schedule. The learning rate
+  now decays to 1e-4 times the maximum LR instead of to zero.
 
 Version 2025.12 - 2025-11-25
 ----------------------------
@@ -53,6 +55,7 @@ Added
   base target (i.e. ``energy``).
 - The ``LLPR`` architecture now allows training LLPR ensembles by backpropagation
   after their creation from the LLPR covariance. This includes support for multi-GPU training.
+- An experimental Muon optimizer has been added for the PET architecture.
 
 Changed
 #######
diff --git a/src/metatrain/experimental/flashmd/checkpoints.py b/src/metatrain/experimental/flashmd/checkpoints.py
index 41376bc2b2..06e23fc028 100644
--- a/src/metatrain/experimental/flashmd/checkpoints.py
+++ b/src/metatrain/experimental/flashmd/checkpoints.py
@@ -11,6 +11,23 @@ def model_update_v1_v2(checkpoint: dict) -> None:
             target.unit = "(eV*u)^(1/2)"
 
 
+def model_update_v2_v3(checkpoint: dict) -> None:
+    """
+    Update a v2 checkpoint to v3.
+
+    :param checkpoint: The checkpoint to update.
+    """
+    for key in ["model_state_dict", "best_model_state_dict"]:
+        if (state_dict := checkpoint.get(key)) is not None:
+            new_state_dict = {}
+            for k, v in state_dict.items():
+                # Replacing the nn.Sequential MLP with a custom FeedForward module
+                if "gnn_layers" in k and ".edge_embedder." in k:
+                    k = k.replace(".edge_embedder.", ".edge_linear.")
+                new_state_dict[k] = v
+            checkpoint[key] = new_state_dict
+
+
 def trainer_update_v1_v2(checkpoint: dict) -> None:
     """
     Update trainer checkpoint from version 1 to version 2.
diff --git a/src/metatrain/experimental/flashmd/model.py b/src/metatrain/experimental/flashmd/model.py
index 940a41f83e..d4a8d15ff0 100644
--- a/src/metatrain/experimental/flashmd/model.py
+++ b/src/metatrain/experimental/flashmd/model.py
@@ -46,7 +46,7 @@ class FlashMD(ModelInterface[ModelHypers]):
     For more information, you can refer to https://arxiv.org/abs/2505.19350.
     """
 
-    __checkpoint_version__ = 2
+    __checkpoint_version__ = 3
     __supported_devices__ = ["cuda", "cpu"]
     __supported_dtypes__ = [torch.float32, torch.float64]
     __default_metadata__ = ModelMetadata(
diff --git a/src/metatrain/experimental/flashmd/tests/checkpoints/model-v3_trainer-v3.ckpt.gz b/src/metatrain/experimental/flashmd/tests/checkpoints/model-v3_trainer-v3.ckpt.gz
new file mode 100644
index 0000000000..303cac23ea
Binary files /dev/null and b/src/metatrain/experimental/flashmd/tests/checkpoints/model-v3_trainer-v3.ckpt.gz differ
diff --git a/src/metatrain/llpr/checkpoints.py b/src/metatrain/llpr/checkpoints.py
index 9fc73c464b..196016445b 100644
--- a/src/metatrain/llpr/checkpoints.py
+++ b/src/metatrain/llpr/checkpoints.py
@@ -52,6 +52,23 @@ def model_update_v2_v3(checkpoint: dict) -> None:
     checkpoint["best_optimizer_state_dict"] = None
 
 
+def model_update_v3_v4(checkpoint: dict) -> None:
+    """
+    Update a v3 checkpoint to v4.
+
+    :param checkpoint: The checkpoint to update.
+ """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + # Replacing the nn.Sequential MLP with a custom FeedForward module + if "gnn_layers" in k and ".edge_embedder." in k: + k = k.replace(".edge_embedder.", ".edge_linear.") + new_state_dict[k] = v + checkpoint[key] = new_state_dict + + def trainer_update_v1_v2(checkpoint: dict) -> None: """ Update trainer checkpoint from version 1 to version 2. diff --git a/src/metatrain/llpr/model.py b/src/metatrain/llpr/model.py index 6d7281d7e9..74caf565d1 100644 --- a/src/metatrain/llpr/model.py +++ b/src/metatrain/llpr/model.py @@ -25,7 +25,7 @@ class LLPRUncertaintyModel(ModelInterface[ModelHypers]): - __checkpoint_version__ = 3 + __checkpoint_version__ = 4 # all torch devices and dtypes are supported, if they are supported by the wrapped # the check is performed in the trainer diff --git a/src/metatrain/llpr/tests/checkpoints/model-v4_trainer-v2.ckpt.gz b/src/metatrain/llpr/tests/checkpoints/model-v4_trainer-v2.ckpt.gz new file mode 100644 index 0000000000..46bc8a25a1 Binary files /dev/null and b/src/metatrain/llpr/tests/checkpoints/model-v4_trainer-v2.ckpt.gz differ diff --git a/src/metatrain/pet/checkpoints.py b/src/metatrain/pet/checkpoints.py index 02c6b0ae5f..e83a9ebb68 100644 --- a/src/metatrain/pet/checkpoints.py +++ b/src/metatrain/pet/checkpoints.py @@ -254,6 +254,23 @@ def model_update_v9_v10(checkpoint: dict) -> None: checkpoint["model_data"]["model_hypers"]["cutoff_function"] = "Cosine" +def model_update_v10_v11(checkpoint: dict) -> None: + """ + Update a v10 checkpoint to v11. + + :param checkpoint: The checkpoint to update. + """ + for key in ["model_state_dict", "best_model_state_dict"]: + if (state_dict := checkpoint.get(key)) is not None: + new_state_dict = {} + for k, v in state_dict.items(): + # Replacing the nn.Sequential MLP with a custom FeedForward module + if "gnn_layers" in k and ".edge_embedder." in k: + k = k.replace(".edge_embedder.", ".edge_linear.") + new_state_dict[k] = v + checkpoint[key] = new_state_dict + + ########################### # TRAINER ################# ########################### @@ -400,3 +417,18 @@ def trainer_update_v10_v11(checkpoint: dict) -> None: atomic_baseline = {target_name: 0.0 for target_name in dataset_info.targets} checkpoint["train_hypers"]["atomic_baseline"] = atomic_baseline + + +def trainer_update_v11_v12(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 11 to version 12. + + :param checkpoint: The checkpoint to update. + """ + # Adding the num_workers=0 hyperparameter if not present + if "optimizer" not in checkpoint["train_hypers"]: + if checkpoint["train_hypers"].get("weight_decay"): + optimizer = "AdamW" + else: + optimizer = "Adam" + checkpoint["train_hypers"]["optimizer"] = optimizer diff --git a/src/metatrain/pet/documentation.py b/src/metatrain/pet/documentation.py index 5884ca2bd0..8a833fc35e 100644 --- a/src/metatrain/pet/documentation.py +++ b/src/metatrain/pet/documentation.py @@ -165,7 +165,10 @@ class TrainerHypers(TypedDict): """Fraction of training steps used for learning rate warmup.""" learning_rate: float = 1e-4 """Learning rate.""" + optimizer: Literal["Adam", "AdamW", "Muon"] = "Adam" + """Optimizer to use for training the model.""" weight_decay: Optional[float] = None + """Weight decay coefficient. 
If None, no weight decay is used.""" log_interval: int = 1 """Interval to log metrics.""" diff --git a/src/metatrain/pet/model.py b/src/metatrain/pet/model.py index 4abaae836e..778054c1de 100644 --- a/src/metatrain/pet/model.py +++ b/src/metatrain/pet/model.py @@ -48,7 +48,7 @@ class PET(ModelInterface[ModelHypers]): targets. """ - __checkpoint_version__ = 10 + __checkpoint_version__ = 11 __supported_devices__ = ["cuda", "cpu"] __supported_dtypes__ = [torch.float32, torch.float64] __default_metadata__ = ModelMetadata( diff --git a/src/metatrain/pet/modules/optimizer.py b/src/metatrain/pet/modules/optimizer.py new file mode 100644 index 0000000000..95f5ab61aa --- /dev/null +++ b/src/metatrain/pet/modules/optimizer.py @@ -0,0 +1,191 @@ +import logging +import math +from typing import Dict, Tuple, Union + +import torch +from packaging import version +from torch.optim.lr_scheduler import LambdaLR + +from ..documentation import TrainerHypers +from ..model import PET + + +def get_optimizer(model: PET, hypers: TrainerHypers) -> torch.optim.Optimizer: + """ + Get the optimizer based on the hyperparameters. + + :param model: The model to optimize. + :param hypers: The training hyperparameters. + :return: The optimizer. + """ + if hypers["weight_decay"] is None: + weight_decay = 0.0 + else: + weight_decay = hypers["weight_decay"] + lr = hypers.get("learning_rate", 1e-4) + if hypers["optimizer"].lower() == "adam": + optimizer = torch.optim.Adam( + model.parameters(), lr=lr, weight_decay=weight_decay + ) + elif hypers["optimizer"].lower() == "adamw": + optimizer = torch.optim.AdamW( + model.parameters(), + lr=lr, + weight_decay=weight_decay, + ) + elif hypers["optimizer"].lower() == "muon": + if version.parse(torch.__version__) < version.parse("2.9.1"): + raise ValueError( + f"The Muon optimizer requires PyTorch >= 2.9.1, but you have " + f"{torch.__version__}. This feature is experimental and so far " + "not well tested. Please manually update PyTorch to use the " + "Muon optimizer." + ) + logging.warning( + "Using the Muon optimizer with auxiliary AdamW for non-matrix " + "parameters. This feature is experimental and so far not well tested. " + "Please use it with caution or set the optimizer to Adam or AdamW in the " + "options.yaml." + ) + # Separate parameters into Muon and Adam groups. + # By design, Muon should only be used for the matrix-type parameters + # (i. e. those with ndim >= 2), and only for optimizing the hidden + # layers of the model (in our case, the GNN layers). All other parameters + # including biases, embeddings, and readout layers (heads) should be + # optimized with Adam or AdamW. + muon_params = [] + adam_params = [] + for n, p in model.named_parameters(): + if p.ndim >= 2 and ( + ("gnn_layers" in n and "neighbor_embedder" not in n) + or "combination_mlps" in n + ): + muon_params.append(p) + else: + adam_params.append(p) + adam_group = dict(params=adam_params, use_muon=False) + muon_group = dict(params=muon_params, use_muon=True) + optimizer = MuonWithAuxAdamW( + [muon_group, adam_group], + lr=lr, + weight_decay=weight_decay, + ) + else: + raise ValueError( + f"Unknown optimizer: {hypers['optimizer']}. Please choose Adam, " + f"AdamW or Muon." + ) + + return optimizer + + +def get_scheduler( + optimizer: torch.optim.Optimizer, + train_hypers: TrainerHypers, + steps_per_epoch: int, +) -> LambdaLR: + """ + Get a CosineAnnealing learning-rate scheduler with warmup + + :param optimizer: The optimizer for which to create the scheduler. 
+ :param train_hypers: The training hyperparameters. + :param steps_per_epoch: The number of steps per epoch. + :return: The learning rate scheduler. + """ + total_steps = train_hypers["num_epochs"] * steps_per_epoch + warmup_steps = int(train_hypers["warmup_fraction"] * total_steps) + min_lr_ratio = 1e-4 # hardcode minimum LR ratio + + logging.info( + f"Using cosine decay from {train_hypers['learning_rate']} to " + f"{train_hypers['learning_rate'] * min_lr_ratio} after " + f"{warmup_steps} warmup optimizer steps and {total_steps} " + "total steps." + ) + + def lr_lambda(current_step: int) -> float: + if current_step < warmup_steps: + # Linear warmup + return float(current_step) / float(max(1, warmup_steps)) + else: + # Cosine decay + progress = (current_step - warmup_steps) / float( + max(1, total_steps - warmup_steps) + ) + cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) + return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay + + scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) + return scheduler + + +class MuonWithAuxAdamW(torch.optim.Optimizer): + """ + Combined optimizer with Muon and AdamW for different parameter groups. + + :param param_groups: Parameter groups for the optimizer. + :param lr: Learning rate. + :param weight_decay: Weight decay. + :param momentum: Momentum for Muon. + :param eps: Epsilon for AdamW. + :param betas: Betas for AdamW. + """ + + def __init__( + self, + param_groups: list, + lr: Union[float, torch.Tensor] = 0.001, + weight_decay: float = 0.0, + momentum: float = 0.95, + eps: float = 1e-10, + betas: Tuple[float, float] = (0.9, 0.95), + ): + # Set defaults that will be merged into param_groups + defaults = dict( + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + eps=eps, + betas=betas, + ) + + # Initialize base optimizer first (this merges defaults into param_groups) + super().__init__(param_groups, defaults) + + # Now create the internal optimizers using the fully initialized param_groups + for group in self.param_groups: + assert "use_muon" in group + params = group["params"] + if group["use_muon"]: + self.muon_optimizer = torch.optim.Muon( + params, + lr=group["lr"], + momentum=group["momentum"], + ) + else: + self.adamw_optimizer = torch.optim.AdamW( + params, + lr=group["lr"], + betas=group["betas"], + eps=group["eps"], + weight_decay=group["weight_decay"], + ) + + @torch.no_grad() + def step(self) -> None: + self.muon_optimizer.step() + self.adamw_optimizer.step() + + def zero_grad(self, set_to_none: bool = True) -> None: + self.muon_optimizer.zero_grad(set_to_none=set_to_none) + self.adamw_optimizer.zero_grad(set_to_none=set_to_none) + + def load_state_dict(self, state_dict: Dict) -> None: + self.muon_optimizer.load_state_dict(state_dict["muon_optimizer"]) + self.adamw_optimizer.load_state_dict(state_dict["adamw_optimizer"]) + + def state_dict(self) -> Dict: + return { + "muon_optimizer": self.muon_optimizer.state_dict(), + "adamw_optimizer": self.adamw_optimizer.state_dict(), + } diff --git a/src/metatrain/pet/modules/transformer.py b/src/metatrain/pet/modules/transformer.py index 86af2f50fb..fe1da89c2a 100644 --- a/src/metatrain/pet/modules/transformer.py +++ b/src/metatrain/pet/modules/transformer.py @@ -386,7 +386,7 @@ def __init__( transformer_type=transformer_type, ) - self.edge_embedder = nn.Linear(4, d_model) + self.edge_linear = nn.Linear(4, d_model) if not is_first: n_merge = 3 @@ -442,7 +442,7 @@ def forward( node_embeddings = input_node_embeddings edge_embeddings = [edge_vectors, edge_distances[:, :, None]] 
edge_embeddings = torch.cat(edge_embeddings, dim=2) - edge_embeddings = self.edge_embedder(edge_embeddings) + edge_embeddings = self.edge_linear(edge_embeddings) if not self.is_first: neighbor_elements_embeddings = self.neighbor_embedder( diff --git a/src/metatrain/pet/tests/checkpoints/model-v11_trainer-v12.ckpt.gz b/src/metatrain/pet/tests/checkpoints/model-v11_trainer-v12.ckpt.gz new file mode 100644 index 0000000000..602947590e Binary files /dev/null and b/src/metatrain/pet/tests/checkpoints/model-v11_trainer-v12.ckpt.gz differ diff --git a/src/metatrain/pet/tests/test_basic.py b/src/metatrain/pet/tests/test_basic.py index f3df853d91..fd9d330367 100644 --- a/src/metatrain/pet/tests/test_basic.py +++ b/src/metatrain/pet/tests/test_basic.py @@ -23,6 +23,7 @@ def minimal_model_hypers(self): hypers = get_default_hypers(self.architecture)["model"] hypers = copy.deepcopy(hypers) hypers["d_pet"] = 1 + hypers["d_node"] = 1 hypers["d_head"] = 1 hypers["d_feedforward"] = 1 hypers["num_heads"] = 1 @@ -68,7 +69,13 @@ class TestTorchscript(TorchscriptTests, PETTests): class TestExported(ExportedTests, PETTests): ... -class TestTraining(TrainingTests, PETTests): ... +class TestTraining(TrainingTests, PETTests): + @pytest.fixture(params=["Adam", "Muon"]) + def default_hypers(self, request): + hypers = get_default_hypers(self.architecture) + hypers = copy.deepcopy(hypers) + hypers["training"]["optimizer"] = request.param + return hypers class TestCheckpoints(CheckpointTests, PETTests): diff --git a/src/metatrain/pet/trainer.py b/src/metatrain/pet/trainer.py index d5fa635f2c..3945ba4f56 100644 --- a/src/metatrain/pet/trainer.py +++ b/src/metatrain/pet/trainer.py @@ -1,11 +1,9 @@ import copy import logging -import math from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Union, cast import torch -from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader, DistributedSampler from metatrain.utils.abc import ModelInterface, TrainerInterface @@ -40,43 +38,11 @@ from .documentation import TrainerHypers from .model import PET from .modules.finetuning import apply_finetuning_strategy - - -def get_scheduler( - optimizer: torch.optim.Optimizer, - train_hypers: TrainerHypers, - steps_per_epoch: int, -) -> LambdaLR: - """ - Get a CosineAnnealing learning-rate scheduler with warmup - - :param optimizer: The optimizer for which to create the scheduler. - :param train_hypers: The training hyperparameters. - :param steps_per_epoch: The number of steps per epoch. - :return: The learning rate scheduler. 
- """ - total_steps = train_hypers["num_epochs"] * steps_per_epoch - warmup_steps = int(train_hypers["warmup_fraction"] * total_steps) - min_lr_ratio = 0.0 # hardcoded for now, could be made configurable in the future - - def lr_lambda(current_step: int) -> float: - if current_step < warmup_steps: - # Linear warmup - return float(current_step) / float(max(1, warmup_steps)) - else: - # Cosine decay - progress = (current_step - warmup_steps) / float( - max(1, total_steps - warmup_steps) - ) - cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) - return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay - - scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) - return scheduler +from .modules.optimizer import get_optimizer, get_scheduler class Trainer(TrainerInterface[TrainerHypers]): - __checkpoint_version__ = 11 + __checkpoint_version__ = 12 def __init__(self, hypers: TrainerHypers) -> None: super().__init__(hypers) @@ -324,16 +290,7 @@ def train( for grad, ginfo in info["gradients"].items(): logging.info(f"\t{name}::{grad}: {ginfo}") - if self.hypers["weight_decay"] is not None: - optimizer = torch.optim.AdamW( - model.parameters(), - lr=self.hypers["learning_rate"], - weight_decay=self.hypers["weight_decay"], - ) - else: - optimizer = torch.optim.Adam( - model.parameters(), lr=self.hypers["learning_rate"] - ) + optimizer = get_optimizer(model, self.hypers) if self.optimizer_state_dict is not None and not is_finetune: # try to load the optimizer state dict, but this is only possible