16 changes: 14 additions & 2 deletions src/metatrain/experimental/classifier/model.py
@@ -20,7 +20,7 @@


class Classifier(ModelInterface[ModelHypers]):
-    __checkpoint_version__ = 1
+    __checkpoint_version__ = 2
Review comment (Contributor):
@frostedoyster is this needed? From the last ml-devel meeting I understood that in principle changes to an architecture's checkpoint shouldn't need a version upgrade on the wrapper.

I think what needs to be done is to simply upgrade the checkpoint file with the new file, without any renaming, but let's see what Filippo says.

(same for llpr)
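For reference, the alternative described in this comment would have the wrapper delegate migration of the nested checkpoint to the wrapped architecture instead of bumping its own `__checkpoint_version__`. A minimal sketch of that idea, assuming PET's `upgrade_checkpoint` classmethod (its signature appears later in this diff) returns the upgraded dict and the nested checkpoint lives under the `wrapped_model_checkpoint` key used below; this is not how the PR actually implements it:

```python
from metatrain.pet.model import PET  # import path assumed for illustration


class Classifier(ModelInterface[ModelHypers]):
    __checkpoint_version__ = 1  # unchanged under this alternative

    @classmethod
    def upgrade_checkpoint(cls, checkpoint: dict) -> dict:
        # Let the wrapped architecture migrate its own nested checkpoint
        # (PET handles v11 -> v12 itself), so no wrapper-level bump is needed.
        checkpoint["wrapped_model_checkpoint"] = PET.upgrade_checkpoint(
            checkpoint["wrapped_model_checkpoint"]
        )
        return checkpoint
```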


# all torch devices and dtypes are supported, if they are supported by the wrapped
# model; the check is performed in the trainer
@@ -339,7 +339,19 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:

@classmethod
def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict:
-        # Currently at version 1, no upgrades needed yet
+        if checkpoint["model_ckpt_version"] == 1:
+            # v1 -> v2: wrapped PET model added system conditioning hypers
+            hypers = checkpoint["wrapped_model_checkpoint"]["model_data"][
+                "model_hypers"
+            ]
+            if "system_conditioning" not in hypers:
+                hypers["system_conditioning"] = False
+            if "max_charge" not in hypers:
+                hypers["max_charge"] = 10
+            if "max_spin" not in hypers:
+                hypers["max_spin"] = 10
+            checkpoint["model_ckpt_version"] = 2

if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__:
raise RuntimeError(
f"Unable to upgrade the checkpoint: the checkpoint is using model "
16 changes: 16 additions & 0 deletions src/metatrain/llpr/checkpoints.py
@@ -52,6 +52,22 @@ def model_update_v2_v3(checkpoint: dict) -> None:
checkpoint["best_optimizer_state_dict"] = None


def model_update_v3_v4(checkpoint: dict) -> None:
"""
Update a v3 checkpoint to v4.

:param checkpoint: The checkpoint to update.
"""
# Upgrade the wrapped PET checkpoint to include system conditioning hypers
hypers = checkpoint["wrapped_model_checkpoint"]["model_data"]["model_hypers"]
if "system_conditioning" not in hypers:
hypers["system_conditioning"] = False
if "max_charge" not in hypers:
hypers["max_charge"] = 10
if "max_spin" not in hypers:
hypers["max_spin"] = 10


def trainer_update_v1_v2(checkpoint: dict) -> None:
"""
Update trainer checkpoint from version 1 to version 2.
2 changes: 1 addition & 1 deletion src/metatrain/llpr/model.py
@@ -39,7 +39,7 @@


class LLPRUncertaintyModel(ModelInterface[ModelHypers]):
-    __checkpoint_version__ = 3
+    __checkpoint_version__ = 4

    # all torch devices and dtypes are supported, if they are supported by the wrapped
    # model; the check is performed in the trainer
15 changes: 15 additions & 0 deletions src/metatrain/pet/checkpoints.py
@@ -265,6 +265,21 @@ def model_update_v10_v11(checkpoint: dict) -> None:
checkpoint["model_data"]["model_hypers"]["attention_temperature"] = 1.0


def model_update_v11_v12(checkpoint: dict) -> None:
"""
Update a v11 checkpoint to v12.

:param checkpoint: The checkpoint to update.
"""
# Adding system conditioning hyperparameters (disabled by default)
if "system_conditioning" not in checkpoint["model_data"]["model_hypers"]:
checkpoint["model_data"]["model_hypers"]["system_conditioning"] = False
if "max_charge" not in checkpoint["model_data"]["model_hypers"]:
checkpoint["model_data"]["model_hypers"]["max_charge"] = 10
if "max_spin" not in checkpoint["model_data"]["model_hypers"]:
checkpoint["model_data"]["model_hypers"]["max_spin"] = 10


###########################
# TRAINER #################
###########################
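Each `model_update_vX_vY` function performs a single version step and mutates the checkpoint in place. The dispatch code that chains them is not part of this diff, so the driver below is only an illustrative sketch of the pattern (the function name `upgrade_model_checkpoint` and the lookup-by-name are assumptions, not metatrain's actual mechanism):

```python
from metatrain.pet import checkpoints  # module path assumed


def upgrade_model_checkpoint(checkpoint: dict, target_version: int) -> dict:
    # Hypothetical driver: walk the checkpoint forward one version at a time,
    # e.g. a v11 checkpoint passes through model_update_v11_v12 exactly once.
    version = checkpoint["model_ckpt_version"]
    while version < target_version:
        updater = getattr(checkpoints, f"model_update_v{version}_v{version + 1}")
        updater(checkpoint)  # updaters mutate the checkpoint dict in place
        version += 1
        checkpoint["model_ckpt_version"] = version
    return checkpoint
```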
11 changes: 11 additions & 0 deletions src/metatrain/pet/documentation.py
@@ -155,6 +155,17 @@ class ModelHypers(TypedDict):
"""Use ZBL potential for short-range repulsion"""
long_range: LongRangeHypers = init_with_defaults(LongRangeHypers)
"""Long-range Coulomb interactions parameters."""
system_conditioning: bool = False
"""Enable charge and spin conditioning embeddings. When enabled, per-system
charge and spin multiplicity are embedded and added to node features at each
GNN layer, allowing different predictions for the same structure under
different electronic states."""
max_charge: int = 10
"""Maximum absolute charge for the conditioning embedding table. Supports
charges in the range ``[-max_charge, +max_charge]``."""
max_spin: int = 10
"""Maximum spin multiplicity (2S+1) for the conditioning embedding table.
Supports values in the range ``[1, max_spin]``."""


class TrainerHypers(TypedDict):
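Taken together, these three hypers work as follows: leaving `system_conditioning` at `False` (the default) skips building the embedding module entirely, while `max_charge` and `max_spin` size the lookup tables when it is enabled. A sketch of the relevant keys as a plain Python dict, mirroring the `ModelHypers` fields above (the surrounding options-file plumbing is omitted):

```python
# Sketch only: enabling conditioning for systems with charges in [-2, 2]
# and spin multiplicities up to 4 (singlet through quartet).
model_hypers = {
    "system_conditioning": True,  # build the SystemConditioningEmbedding
    "max_charge": 2,              # embedding table covers charges -2..+2
    "max_spin": 4,                # embedding table covers multiplicities 1..4
    # ... all other PET hypers keep their defaults
}
```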
51 changes: 50 additions & 1 deletion src/metatrain/pet/model.py
@@ -28,6 +28,7 @@

from . import checkpoints
from .documentation import ModelHypers
from .modules.conditioning import SystemConditioningEmbedding
from .modules.finetuning import apply_finetuning_strategy
from .modules.structures import systems_to_batch
from .modules.transformer import CartesianTransformer
@@ -48,7 +49,7 @@ class PET(ModelInterface[ModelHypers]):
targets.
"""

-    __checkpoint_version__ = 11
+    __checkpoint_version__ = 12
__supported_devices__ = ["cuda", "cpu"]
__supported_dtypes__ = [torch.float32, torch.float64]
__default_metadata__ = ModelMetadata(
@@ -142,6 +143,17 @@ def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None:
)
self.edge_embedder = torch.nn.Embedding(num_atomic_species, self.d_pet)

if self.hypers.get("system_conditioning", False):
self.system_conditioning: Optional[SystemConditioningEmbedding] = (
SystemConditioningEmbedding(
d_out=self.d_node,
max_charge=self.hypers.get("max_charge", 10),
max_spin=self.hypers.get("max_spin", 10),
)
)
else:
self.system_conditioning = None

self.node_heads = torch.nn.ModuleDict()
self.edge_heads = torch.nn.ModuleDict()
self.node_last_layers = torch.nn.ModuleDict()
@@ -443,6 +455,23 @@ def forward(
padding_mask=padding_mask,
cutoff_factors=cutoff_factors,
)

# Extract per-system charge and spin for conditioning
if self.system_conditioning is not None:
n_systems = len(systems)
charges = torch.zeros(n_systems, dtype=torch.long, device=device)
spins = torch.ones(n_systems, dtype=torch.long, device=device)
for i, system in enumerate(systems):
if "mtt::charge" in system.known_data():
charges[i] = (
system.get_data("mtt::charge").block().values.long()
)
if "mtt::spin" in system.known_data():
spins[i] = system.get_data("mtt::spin").block().values.long()
self.system_conditioning.validate(charges, spins)
featurizer_inputs["charge"] = charges
featurizer_inputs["spin"] = spins
featurizer_inputs["system_indices"] = system_indices
node_features_list, edge_features_list = self._calculate_features(
featurizer_inputs,
use_manual_attention=use_manual_attention,
@@ -628,6 +657,16 @@ def _feedforward_featurization_impl(
use_manual_attention,
)

# Add system conditioning (charge/spin) to node features
if self.system_conditioning is not None:
output_node_embeddings = output_node_embeddings + (
self.system_conditioning(
inputs["charge"],
inputs["spin"],
inputs["system_indices"],
)
)

        # The GNN contraction happens by reordering the messages,
        # using a reversed neighbor list, so the new input message
        # from atom `j` to atom `i` in the GNN layer N+1 is a
@@ -687,6 +726,16 @@ def _residual_featurization_impl(
inputs["cutoff_factors"],
use_manual_attention,
)
# Add system conditioning (charge/spin) to node features
if self.system_conditioning is not None:
output_node_embeddings = output_node_embeddings + (
self.system_conditioning(
inputs["charge"],
inputs["spin"],
inputs["system_indices"],
)
)

node_features_list.append(output_node_embeddings)
edge_features_list.append(output_edge_embeddings)

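The `forward` changes above look up per-system charge and spin as extra data attached to each `System` under the keys `mtt::charge` and `mtt::spin`, defaulting to charge 0 and multiplicity 1 when a key is absent. Since `get_data(...).block().values` is called on the result, each entry must be a `TensorMap` holding a single scalar. A sketch of attaching such data, assuming the metatomic/metatensor-torch constructors named below (the exact `Labels` layouts are illustrative guesses, not a format mandated by the diff):

```python
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatomic.torch import System  # assumed; may live in metatensor.torch.atomistic

# A toy three-atom, non-periodic system (water-like, positions collapsed to zero).
system = System(
    types=torch.tensor([8, 1, 1]),
    positions=torch.zeros(3, 3, dtype=torch.float64),
    cell=torch.zeros(3, 3, dtype=torch.float64),
    pbc=torch.tensor([False, False, False]),
)


def scalar_map(value: float) -> TensorMap:
    # One block holding a single [1, 1] value.
    block = TensorBlock(
        values=torch.tensor([[value]], dtype=torch.float64),
        samples=Labels.range("system", 1),
        components=[],
        properties=Labels.range("value", 1),
    )
    return TensorMap(keys=Labels.range("_", 1), blocks=[block])


system.add_data("mtt::charge", scalar_map(-1.0))  # an anion
system.add_data("mtt::spin", scalar_map(2.0))     # doublet, i.e. 2S+1 = 2
```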
78 changes: 78 additions & 0 deletions src/metatrain/pet/modules/conditioning.py
@@ -0,0 +1,78 @@
"""System-level conditioning embeddings for charge and spin."""

import torch


class SystemConditioningEmbedding(torch.nn.Module):
"""Embeds per-system charge and spin multiplicity into per-atom features.

Each system in a batch can have a different total charge and spin multiplicity.
These are embedded via learned lookup tables, concatenated, and projected down
to the model's feature dimension. The resulting per-system embedding is broadcast
to all atoms belonging to that system.

:param d_out: Output embedding dimension (should match d_node).
:param max_charge: Maximum absolute charge value. Supports charges in
the range ``[-max_charge, +max_charge]``.
:param max_spin: Maximum spin multiplicity (2S+1). Supports values in
the range ``[1, max_spin]``.
"""

def __init__(self, d_out: int, max_charge: int = 10, max_spin: int = 10):
super().__init__()
self.max_charge = max_charge
self.max_spin = max_spin
d_inner = d_out
self.charge_embedding = torch.nn.Embedding(2 * max_charge + 1, d_inner)
self.spin_embedding = torch.nn.Embedding(max_spin, d_inner)
gate = torch.nn.Linear(d_inner, d_out)
torch.nn.init.zeros_(gate.weight)
torch.nn.init.zeros_(gate.bias)
self.project = torch.nn.Sequential(
torch.nn.Linear(2 * d_inner, d_inner),
torch.nn.SiLU(),
gate,
)

def validate(self, charge: torch.Tensor, spin: torch.Tensor) -> None:
"""Check that charge and spin values are within the supported range.

Call this outside of ``torch.compile`` regions to get descriptive errors.

:param charge: Per-system total charge, shape ``[n_systems]``.
:param spin: Per-system spin multiplicity, shape ``[n_systems]``.
"""
if (charge < -self.max_charge).any() or (charge > self.max_charge).any():
raise ValueError(
f"charge values must be in [{-self.max_charge}, "
f"{self.max_charge}], got min={charge.min().item()}, "
f"max={charge.max().item()}. Increase max_charge in "
f"model hypers to support wider charge ranges."
)
if (spin < 1).any() or (spin > self.max_spin).any():
raise ValueError(
f"spin multiplicity values must be in [1, "
f"{self.max_spin}], got min={spin.min().item()}, "
f"max={spin.max().item()}. Increase max_spin in "
f"model hypers to support higher spin multiplicities."
)

def forward(
self,
charge: torch.Tensor,
spin: torch.Tensor,
system_indices: torch.Tensor,
) -> torch.Tensor:
"""Compute per-atom conditioning features from per-system charge and spin.

:param charge: Per-system total charge, shape ``[n_systems]``, integer.
:param spin: Per-system spin multiplicity (2S+1), shape ``[n_systems]``,
integer >= 1.
:param system_indices: Maps each atom to its system index,
shape ``[n_atoms]``.
:return: Per-atom conditioning features, shape ``[n_atoms, d_out]``.
"""
c_emb = self.charge_embedding(charge + self.max_charge) # [n_systems, d_out]
s_emb = self.spin_embedding(spin - 1) # [n_systems, d_out]
        system_emb = self.project(torch.cat([c_emb, s_emb], dim=-1))  # [n_systems, d_out]
return system_emb[system_indices] # [n_atoms, d_out]
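As a usage sketch of the module above: two systems, three atoms total. Note the zero-initialized `gate` layer in `project`: a freshly constructed embedding contributes exactly nothing, so enabling conditioning on a pretrained model leaves its predictions unchanged until training moves the gate weights.

```python
import torch

emb = SystemConditioningEmbedding(d_out=128, max_charge=10, max_spin=10)

charge = torch.tensor([0, -1])            # system 0 is neutral, system 1 an anion
spin = torch.tensor([1, 2])               # singlet and doublet multiplicities
system_indices = torch.tensor([0, 0, 1])  # atoms 0-1 -> system 0, atom 2 -> system 1

emb.validate(charge, spin)  # raises a descriptive ValueError if out of range
features = emb(charge, spin, system_indices)

print(features.shape)        # torch.Size([3, 128])
print(features.abs().max())  # 0 at initialization: the gate layer starts at zero
```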