3 changes: 3 additions & 0 deletions docs/src/dev-docs/changelog.rst
@@ -30,6 +30,8 @@ Changed
- SOAP-BPNN and MCoV now use species embeddings by default, allowing for better
scalability and speed. The traditional SOAP-BPNN (and associated MCoV) architecture
can be accessed by setting ``legacy: True``
- A minimum learning rate ratio has been added to the PET LR schedule, so the LR
never decays below 1e-4 times the maximum LR (a sketch of the schedule follows this diff).

Version 2025.12 - 2025-11-25
----------------------------
@@ -53,6 +55,7 @@ Added
base target (i.e. ``energy``).
- The ``LLPR`` architecture now allows training LLPR ensembles by backpropagation after
their creation from the LLPR covariance. This includes support for multi-GPU training.
- An experimental Muon optimizer has been added for the PET architecture.

Changed
#######
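For reference, a sketch of the schedule the changelog entry refers to (it matches get_scheduler in src/metatrain/pet/modules/optimizer.py further down in this diff): with the hardcoded ratio r_min = 1e-4, the learning rate after the linear warmup at optimizer step t is

    lr(t) = lr_max * ( r_min + (1 - r_min) * 0.5 * (1 + cos(pi * (t - t_warmup) / (T - t_warmup))) )

where t_warmup is the number of warmup steps and T the total number of steps, so the LR ends at exactly 1e-4 * lr_max.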
17 changes: 17 additions & 0 deletions src/metatrain/experimental/flashmd/checkpoints.py
@@ -11,6 +11,23 @@ def model_update_v1_v2(checkpoint: dict) -> None:
target.unit = "(eV*u)^(1/2)"


def model_update_v2_v3(checkpoint: dict) -> None:
"""
Update a v2 checkpoint to v3.

:param checkpoint: The checkpoint to update.
"""
for key in ["model_state_dict", "best_model_state_dict"]:
if (state_dict := checkpoint.get(key)) is not None:
new_state_dict = {}
for k, v in state_dict.items():
# Replacing the nn.Sequential MLP with a custom FeedForward module
if "gnn_layers" in k and ".edge_embedder." in k:
k = k.replace(".edge_embedder.", ".edge_linear.")
new_state_dict[k] = v
checkpoint[key] = new_state_dict


def trainer_update_v1_v2(checkpoint: dict) -> None:
"""
Update trainer checkpoint from version 1 to version 2.
2 changes: 1 addition & 1 deletion src/metatrain/experimental/flashmd/model.py
@@ -46,7 +46,7 @@ class FlashMD(ModelInterface[ModelHypers]):
For more information, you can refer to https://arxiv.org/abs/2505.19350.
"""

__checkpoint_version__ = 2
__checkpoint_version__ = 3
__supported_devices__ = ["cuda", "cpu"]
__supported_dtypes__ = [torch.float32, torch.float64]
__default_metadata__ = ModelMetadata(
Binary file not shown.
17 changes: 17 additions & 0 deletions src/metatrain/llpr/checkpoints.py
@@ -52,6 +52,23 @@ def model_update_v2_v3(checkpoint: dict) -> None:
checkpoint["best_optimizer_state_dict"] = None


def model_update_v3_v4(checkpoint: dict) -> None:
"""
Update a v3 checkpoint to v4.

:param checkpoint: The checkpoint to update.
"""
for key in ["model_state_dict", "best_model_state_dict"]:
if (state_dict := checkpoint.get(key)) is not None:
new_state_dict = {}
for k, v in state_dict.items():
# Replacing the nn.Sequential MLP with a custom FeedForward module
if "gnn_layers" in k and ".edge_embedder." in k:
k = k.replace(".edge_embedder.", ".edge_linear.")
new_state_dict[k] = v
checkpoint[key] = new_state_dict


def trainer_update_v1_v2(checkpoint: dict) -> None:
"""
Update trainer checkpoint from version 1 to version 2.
2 changes: 1 addition & 1 deletion src/metatrain/llpr/model.py
@@ -25,7 +25,7 @@


class LLPRUncertaintyModel(ModelInterface[ModelHypers]):
__checkpoint_version__ = 3
__checkpoint_version__ = 4

# all torch devices and dtypes are supported, if they are supported by the wrapped
# the check is performed in the trainer
Binary file not shown.
32 changes: 32 additions & 0 deletions src/metatrain/pet/checkpoints.py
@@ -254,6 +254,23 @@ def model_update_v9_v10(checkpoint: dict) -> None:
checkpoint["model_data"]["model_hypers"]["cutoff_function"] = "Cosine"


def model_update_v10_v11(checkpoint: dict) -> None:
"""
Update a v10 checkpoint to v11.

:param checkpoint: The checkpoint to update.
"""
for key in ["model_state_dict", "best_model_state_dict"]:
if (state_dict := checkpoint.get(key)) is not None:
new_state_dict = {}
for k, v in state_dict.items():
# Replacing the nn.Sequential MLP with a custom FeedForward module
if "gnn_layers" in k and ".edge_embedder." in k:
k = k.replace(".edge_embedder.", ".edge_linear.")
new_state_dict[k] = v
checkpoint[key] = new_state_dict


###########################
# TRAINER #################
###########################
@@ -400,3 +417,18 @@ def trainer_update_v10_v11(checkpoint: dict) -> None:
atomic_baseline = {target_name: 0.0 for target_name in dataset_info.targets}

checkpoint["train_hypers"]["atomic_baseline"] = atomic_baseline


def trainer_update_v11_v12(checkpoint: dict) -> None:
"""
Update trainer checkpoint from version 11 to version 12.

:param checkpoint: The checkpoint to update.
"""
# Adding the optimizer hyperparameter if not present, inferred from weight_decay
if "optimizer" not in checkpoint["train_hypers"]:
if checkpoint["train_hypers"].get("weight_decay"):
optimizer = "AdamW"
else:
optimizer = "Adam"
checkpoint["train_hypers"]["optimizer"] = optimizer
3 changes: 3 additions & 0 deletions src/metatrain/pet/documentation.py
@@ -165,7 +165,10 @@ class TrainerHypers(TypedDict):
"""Fraction of training steps used for learning rate warmup."""
learning_rate: float = 1e-4
"""Learning rate."""
optimizer: Literal["Adam", "AdamW", "Muon"] = "Adam"
"""Optimizer to use for training the model."""
weight_decay: Optional[float] = None
"""Weight decay coefficient. If None, no weight decay is used."""

log_interval: int = 1
"""Interval to log metrics."""
2 changes: 1 addition & 1 deletion src/metatrain/pet/model.py
@@ -48,7 +48,7 @@ class PET(ModelInterface[ModelHypers]):
targets.
"""

__checkpoint_version__ = 10
__checkpoint_version__ = 11
__supported_devices__ = ["cuda", "cpu"]
__supported_dtypes__ = [torch.float32, torch.float64]
__default_metadata__ = ModelMetadata(
191 changes: 191 additions & 0 deletions src/metatrain/pet/modules/optimizer.py
@@ -0,0 +1,191 @@
import logging
import math
from typing import Dict, Tuple, Union

import torch
from packaging import version
from torch.optim.lr_scheduler import LambdaLR

from ..documentation import TrainerHypers
from ..model import PET


def get_optimizer(model: PET, hypers: TrainerHypers) -> torch.optim.Optimizer:
"""
Get the optimizer based on the hyperparameters.

:param model: The model to optimize.
:param hypers: The training hyperparameters.
:return: The optimizer.
"""
if hypers["weight_decay"] is None:
weight_decay = 0.0
else:
weight_decay = hypers["weight_decay"]
lr = hypers.get("learning_rate", 1e-4)
if hypers["optimizer"].lower() == "adam":
optimizer = torch.optim.Adam(
model.parameters(), lr=lr, weight_decay=weight_decay
)
elif hypers["optimizer"].lower() == "adamw":
optimizer = torch.optim.AdamW(
model.parameters(),
lr=lr,
weight_decay=weight_decay,
)
elif hypers["optimizer"].lower() == "muon":
if version.parse(torch.__version__) < version.parse("2.9.1"):
raise ValueError(
f"The Muon optimizer requires PyTorch >= 2.9.1, but you have "
f"{torch.__version__}. This feature is experimental and so far "
"not well tested. Please manually update PyTorch to use the "
"Muon optimizer."
)
logging.warning(
"Using the Muon optimizer with auxiliary AdamW for non-matrix "
"parameters. This feature is experimental and so far not well tested. "
"Please use it with caution or set the optimizer to Adam or AdamW in the "
"options.yaml."
)
# Separate parameters into Muon and Adam groups.
# By design, Muon should only be used for the matrix-type parameters
# (i.e. those with ndim >= 2), and only for optimizing the hidden
# layers of the model (in our case, the GNN layers). All other parameters,
# including biases, embeddings, and readout layers (heads), should be
# optimized with Adam or AdamW.
muon_params = []
adam_params = []
for n, p in model.named_parameters():
if p.ndim >= 2 and (
("gnn_layers" in n and "neighbor_embedder" not in n)
or "combination_mlps" in n
):
muon_params.append(p)
else:
adam_params.append(p)
adam_group = dict(params=adam_params, use_muon=False)
Member:
I think we want two separate learning rates for the Adam and Muon parameter groups.

If you look at the example from the README of https://github.com/KellerJordan/Muon:

from muon import MuonWithAuxAdam
hidden_weights = [p for p in model.body.parameters() if p.ndim >= 2]
hidden_gains_biases = [p for p in model.body.parameters() if p.ndim < 2]
nonhidden_params = [*model.head.parameters(), *model.embed.parameters()]
param_groups = [
    dict(params=hidden_weights, use_muon=True,
         lr=0.02, weight_decay=0.01),
    dict(params=hidden_gains_biases+nonhidden_params, use_muon=False,
         lr=3e-4, betas=(0.9, 0.95), weight_decay=0.01),
]
optimizer = MuonWithAuxAdam(param_groups)

Here the Adam LR is more what we'd normally expect, but the Muon one can be pushed much higher.

@abmazitov (Contributor, author), Dec 15, 2025:
I'm a bit sceptical about setting the LR values like this. I mean, they would be highly architecture-dependent, right? At the same time, @sirmarcel has tested Muon for PET and noticed that it works nicely even with a common LR of ~1e-3 for both Adam and Muon parameters.

Member:
Hmm, OK. I noticed in my tests that I could push the Muon LR even to 1e-1 and it was still stable, but as soon as the Adam LR went above 1e-3, training diverged. But again, an extra hyperparameter is more complexity, so let's keep it simple and have a single one, as you say, for now.

muon_group = dict(params=muon_params, use_muon=True)
optimizer = MuonWithAuxAdamW(
[muon_group, adam_group],
lr=lr,
weight_decay=weight_decay,
)
else:
raise ValueError(
f"Unknown optimizer: {hypers['optimizer']}. Please choose Adam, "
f"AdamW or Muon."
)

return optimizer


def get_scheduler(
optimizer: torch.optim.Optimizer,
train_hypers: TrainerHypers,
steps_per_epoch: int,
) -> LambdaLR:
"""
Get a cosine-annealing learning-rate scheduler with linear warmup.

:param optimizer: The optimizer for which to create the scheduler.
:param train_hypers: The training hyperparameters.
:param steps_per_epoch: The number of steps per epoch.
:return: The learning rate scheduler.
"""
total_steps = train_hypers["num_epochs"] * steps_per_epoch
warmup_steps = int(train_hypers["warmup_fraction"] * total_steps)
min_lr_ratio = 1e-4 # hardcode minimum LR ratio

logging.info(
f"Using cosine decay from {train_hypers['learning_rate']} to "
f"{train_hypers['learning_rate'] * min_lr_ratio} after "
f"{warmup_steps} warmup optimizer steps and {total_steps} "
"total steps."
)

def lr_lambda(current_step: int) -> float:
if current_step < warmup_steps:
# Linear warmup
return float(current_step) / float(max(1, warmup_steps))
else:
# Cosine decay
progress = (current_step - warmup_steps) / float(
max(1, total_steps - warmup_steps)
)
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay

scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
return scheduler


class MuonWithAuxAdamW(torch.optim.Optimizer):
"""
Combined optimizer with Muon and AdamW for different parameter groups.

:param param_groups: Parameter groups for the optimizer.
:param lr: Learning rate.
:param weight_decay: Weight decay.
:param momentum: Momentum for Muon.
:param eps: Epsilon for AdamW.
:param betas: Betas for AdamW.
"""

def __init__(
self,
param_groups: list,
lr: Union[float, torch.Tensor] = 0.001,
weight_decay: float = 0.0,
momentum: float = 0.95,
eps: float = 1e-10,
betas: Tuple[float, float] = (0.9, 0.95),
):
# Set defaults that will be merged into param_groups
defaults = dict(
lr=lr,
weight_decay=weight_decay,
momentum=momentum,
eps=eps,
betas=betas,
)

# Initialize base optimizer first (this merges defaults into param_groups)
super().__init__(param_groups, defaults)

# Now create the internal optimizers using the fully initialized param_groups
for group in self.param_groups:
assert "use_muon" in group
params = group["params"]
if group["use_muon"]:
self.muon_optimizer = torch.optim.Muon(
params,
lr=group["lr"],
momentum=group["momentum"],
)
else:
self.adamw_optimizer = torch.optim.AdamW(
params,
lr=group["lr"],
betas=group["betas"],
eps=group["eps"],
weight_decay=group["weight_decay"],
)

@torch.no_grad()
def step(self) -> None:
self.muon_optimizer.step()
self.adamw_optimizer.step()

def zero_grad(self, set_to_none: bool = True) -> None:
self.muon_optimizer.zero_grad(set_to_none=set_to_none)
self.adamw_optimizer.zero_grad(set_to_none=set_to_none)

def load_state_dict(self, state_dict: Dict) -> None:
self.muon_optimizer.load_state_dict(state_dict["muon_optimizer"])
self.adamw_optimizer.load_state_dict(state_dict["adamw_optimizer"])

def state_dict(self) -> Dict:
return {
"muon_optimizer": self.muon_optimizer.state_dict(),
"adamw_optimizer": self.adamw_optimizer.state_dict(),
}
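For orientation, a minimal sketch of how these helpers are wired into a training loop. The hypers dict, model, dataloader, and compute_loss below are illustrative stand-ins, not metatrain's actual Trainer code.

from metatrain.pet.modules.optimizer import get_optimizer, get_scheduler

# Illustrative values only; the real defaults live in src/metatrain/pet/documentation.py.
hypers = {
    "optimizer": "AdamW",  # or "Adam" / "Muon" (Muon requires PyTorch >= 2.9.1)
    "learning_rate": 1e-4,
    "weight_decay": 0.01,
    "num_epochs": 100,
    "warmup_fraction": 0.1,
}

optimizer = get_optimizer(model, hypers)  # model: a PET instance (assumed to exist)
scheduler = get_scheduler(optimizer, hypers, steps_per_epoch=len(dataloader))

for epoch in range(hypers["num_epochs"]):
    for batch in dataloader:  # dataloader / compute_loss: hypothetical stand-ins
        optimizer.zero_grad()
        loss = compute_loss(model, batch)
        loss.backward()
        optimizer.step()
        scheduler.step()  # step-based: linear warmup, then cosine decay to 1e-4 * lr

Note that torch.optim.Optimizer only fills in defaults for keys a parameter group does not set itself, so passing lr inside the individual group dicts appears sufficient to give MuonWithAuxAdamW per-group learning rates, should the separate-LR idea from the review thread above ever be revisited.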
4 changes: 2 additions & 2 deletions src/metatrain/pet/modules/transformer.py
@@ -386,7 +386,7 @@ def __init__(
transformer_type=transformer_type,
)

self.edge_embedder = nn.Linear(4, d_model)
self.edge_linear = nn.Linear(4, d_model)

if not is_first:
n_merge = 3
Expand Down Expand Up @@ -442,7 +442,7 @@ def forward(
node_embeddings = input_node_embeddings
edge_embeddings = [edge_vectors, edge_distances[:, :, None]]
edge_embeddings = torch.cat(edge_embeddings, dim=2)
edge_embeddings = self.edge_embedder(edge_embeddings)
edge_embeddings = self.edge_linear(edge_embeddings)

if not self.is_first:
neighbor_elements_embeddings = self.neighbor_embedder(
Contributor:
Can this be regenerated with different hyperparameters?

Member:
Sure - just the PET one or the others too?

Contributor:
It would be great if you could do it for all the newly generated checkpoints!

Binary file not shown.
9 changes: 8 additions & 1 deletion src/metatrain/pet/tests/test_basic.py
@@ -23,6 +23,7 @@ def minimal_model_hypers(self):
hypers = get_default_hypers(self.architecture)["model"]
hypers = copy.deepcopy(hypers)
hypers["d_pet"] = 1
hypers["d_node"] = 1
hypers["d_head"] = 1
hypers["d_feedforward"] = 1
hypers["num_heads"] = 1
@@ -68,7 +69,13 @@ class TestTorchscript(TorchscriptTests, PETTests):
class TestExported(ExportedTests, PETTests): ...


class TestTraining(TrainingTests, PETTests): ...
class TestTraining(TrainingTests, PETTests):
@pytest.fixture(params=["Adam", "Muon"])
def default_hypers(self, request):
hypers = get_default_hypers(self.architecture)
hypers = copy.deepcopy(hypers)
hypers["training"]["optimizer"] = request.param
return hypers


class TestCheckpoints(CheckpointTests, PETTests):