From f81dfbddd9787b4d7e26d4ab05a6b4802776500e Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Sat, 7 Feb 2026 23:33:55 +0100 Subject: [PATCH 1/9] add charge and spin to dataset --- pyproject.toml | 4 +- src/metatrain/pet/documentation.py | 11 + src/metatrain/pet/model.py | 48 +++++ src/metatrain/pet/modules/conditioning.py | 64 ++++++ src/metatrain/pet/tests/test_conditioning.py | 213 +++++++++++++++++++ src/metatrain/utils/data/dataset.py | 63 +++++- src/metatrain/utils/data/get_dataset.py | 6 +- 7 files changed, 406 insertions(+), 3 deletions(-) create mode 100644 src/metatrain/pet/modules/conditioning.py create mode 100644 src/metatrain/pet/tests/test_conditioning.py diff --git a/pyproject.toml b/pyproject.toml index 074e0fefb0..9b3d1f21ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -176,7 +176,9 @@ filterwarnings = [ # Multi-threaded tests clash with multi-process data-loading "ignore:This process \\(pid=\\d+\\) is multi-threaded, use of fork\\(\\) may lead to deadlocks in the child.:DeprecationWarning", # MACE warning with newer versions of pytorch (because they use e3nn==0.4.4) - "ignore:Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the`weights_only` argument was not explicitly passed:UserWarning" + "ignore:Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the`weights_only` argument was not explicitly passed:UserWarning", + # torch.jit.script deprecation in newer PyTorch versions + "ignore:`torch.jit.script` is deprecated:DeprecationWarning" ] addopts = ["-p", "mtt_plugin"] pythonpath = "src/metatrain/utils/testing" diff --git a/src/metatrain/pet/documentation.py b/src/metatrain/pet/documentation.py index 8eaec9c03e..0d72f0eb38 100644 --- a/src/metatrain/pet/documentation.py +++ b/src/metatrain/pet/documentation.py @@ -155,6 +155,17 @@ class ModelHypers(TypedDict): """Use ZBL potential for short-range repulsion""" long_range: LongRangeHypers = init_with_defaults(LongRangeHypers) """Long-range Coulomb interactions parameters.""" + system_conditioning: bool = False + """Enable charge and spin conditioning embeddings. When enabled, per-system + charge and spin multiplicity are embedded and added to node features at each + GNN layer, allowing different predictions for the same structure under + different electronic states.""" + max_charge: int = 10 + """Maximum absolute charge for the conditioning embedding table. Supports + charges in the range ``[-max_charge, +max_charge]``.""" + max_spin: int = 10 + """Maximum spin multiplicity (2S+1) for the conditioning embedding table. + Supports values in the range ``[1, max_spin]``.""" class TrainerHypers(TypedDict): diff --git a/src/metatrain/pet/model.py b/src/metatrain/pet/model.py index 757e778140..177dfd6700 100644 --- a/src/metatrain/pet/model.py +++ b/src/metatrain/pet/model.py @@ -28,6 +28,7 @@ from . import checkpoints from .documentation import ModelHypers +from .modules.conditioning import SystemConditioningEmbedding from .modules.finetuning import apply_finetuning_strategy from .modules.structures import systems_to_batch from .modules.transformer import CartesianTransformer @@ -142,6 +143,17 @@ def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None: ) self.edge_embedder = torch.nn.Embedding(num_atomic_species, self.d_pet) + if self.hypers.get("system_conditioning", False): + self.system_conditioning: Optional[SystemConditioningEmbedding] = ( + SystemConditioningEmbedding( + d_out=self.d_pet, + max_charge=self.hypers.get("max_charge", 10), + max_spin=self.hypers.get("max_spin", 10), + ) + ) + else: + self.system_conditioning = None + self.node_heads = torch.nn.ModuleDict() self.edge_heads = torch.nn.ModuleDict() self.node_last_layers = torch.nn.ModuleDict() @@ -443,6 +455,22 @@ def forward( padding_mask=padding_mask, cutoff_factors=cutoff_factors, ) + + # Extract per-system charge and spin for conditioning + if self.system_conditioning is not None: + n_systems = len(systems) + charges = torch.zeros(n_systems, dtype=torch.long, device=device) + spins = torch.ones(n_systems, dtype=torch.long, device=device) + for i, system in enumerate(systems): + if "mtt::charge" in system.known_data(): + charges[i] = ( + system.get_data("mtt::charge").block().values.long() + ) + if "mtt::spin" in system.known_data(): + spins[i] = system.get_data("mtt::spin").block().values.long() + featurizer_inputs["charge"] = charges + featurizer_inputs["spin"] = spins + featurizer_inputs["system_indices"] = system_indices node_features_list, edge_features_list = self._calculate_features( featurizer_inputs, use_manual_attention=use_manual_attention, @@ -628,6 +656,16 @@ def _feedforward_featurization_impl( use_manual_attention, ) + # Add system conditioning (charge/spin) to node features + if self.system_conditioning is not None: + output_node_embeddings = output_node_embeddings + ( + self.system_conditioning( + inputs["charge"], + inputs["spin"], + inputs["system_indices"], + ) + ) + # The GNN contraction happens by reordering the messages, # using a reversed neighbor list, so the new input message # from atom `j` to atom `i` in on the GNN layer N+1 is a @@ -687,6 +725,16 @@ def _residual_featurization_impl( inputs["cutoff_factors"], use_manual_attention, ) + # Add system conditioning (charge/spin) to node features + if self.system_conditioning is not None: + output_node_embeddings = output_node_embeddings + ( + self.system_conditioning( + inputs["charge"], + inputs["spin"], + inputs["system_indices"], + ) + ) + node_features_list.append(output_node_embeddings) edge_features_list.append(output_edge_embeddings) diff --git a/src/metatrain/pet/modules/conditioning.py b/src/metatrain/pet/modules/conditioning.py new file mode 100644 index 0000000000..3de8336db5 --- /dev/null +++ b/src/metatrain/pet/modules/conditioning.py @@ -0,0 +1,64 @@ +"""System-level conditioning embeddings for charge and spin.""" + +import torch + + +class SystemConditioningEmbedding(torch.nn.Module): + """Embeds per-system charge and spin multiplicity into per-atom features. + + Each system in a batch can have a different total charge and spin multiplicity. + These are embedded via learned lookup tables, concatenated, and projected down + to the model's feature dimension. The resulting per-system embedding is broadcast + to all atoms belonging to that system. + + :param d_out: Output embedding dimension (should match d_pet). + :param max_charge: Maximum absolute charge value. Supports charges in + the range ``[-max_charge, +max_charge]``. + :param max_spin: Maximum spin multiplicity (2S+1). Supports values in + the range ``[1, max_spin]``. + """ + + def __init__(self, d_out: int, max_charge: int = 10, max_spin: int = 10): + super().__init__() + self.max_charge = max_charge + self.max_spin = max_spin + self.charge_embedding = torch.nn.Embedding(2 * max_charge + 1, d_out) + self.spin_embedding = torch.nn.Embedding(max_spin, d_out) + self.project = torch.nn.Sequential( + torch.nn.Linear(2 * d_out, d_out), + torch.nn.SiLU(), + ) + + def forward( + self, + charge: torch.Tensor, + spin: torch.Tensor, + system_indices: torch.Tensor, + ) -> torch.Tensor: + """Compute per-atom conditioning features from per-system charge and spin. + + :param charge: Per-system total charge, shape ``[n_systems]``, integer. + :param spin: Per-system spin multiplicity (2S+1), shape ``[n_systems]``, + integer >= 1. + :param system_indices: Maps each atom to its system index, + shape ``[n_atoms]``. + :return: Per-atom conditioning features, shape ``[n_atoms, d_out]``. + """ + if (charge < -self.max_charge).any() or (charge > self.max_charge).any(): + raise ValueError( + f"charge values must be in [{-self.max_charge}, " + f"{self.max_charge}], got min={charge.min().item()}, " + f"max={charge.max().item()}. Increase max_charge in " + f"model hypers to support wider charge ranges." + ) + if (spin < 1).any() or (spin > self.max_spin).any(): + raise ValueError( + f"spin multiplicity values must be in [1, " + f"{self.max_spin}], got min={spin.min().item()}, " + f"max={spin.max().item()}. Increase max_spin in " + f"model hypers to support higher spin multiplicities." + ) + c_emb = self.charge_embedding(charge + self.max_charge) # [n_systems, d_out] + s_emb = self.spin_embedding(spin - 1) # [n_systems, d_out] + system_emb = self.project(torch.cat([c_emb, s_emb], dim=-1)) # [n_systems, d] + return system_emb[system_indices] # [n_atoms, d_out] diff --git a/src/metatrain/pet/tests/test_conditioning.py b/src/metatrain/pet/tests/test_conditioning.py new file mode 100644 index 0000000000..10d42b6e72 --- /dev/null +++ b/src/metatrain/pet/tests/test_conditioning.py @@ -0,0 +1,213 @@ +"""Tests for charge and spin conditioning embeddings in PET.""" + +import copy + +import pytest +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatomic.torch import ModelOutput, System + +from metatrain.pet import PET +from metatrain.pet.modules.conditioning import SystemConditioningEmbedding +from metatrain.utils.architectures import get_default_hypers +from metatrain.utils.data import DatasetInfo +from metatrain.utils.data.target_info import get_energy_target_info +from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists + + +def _small_hypers(system_conditioning=True, **kwargs): + """Return minimal PET hypers with system conditioning.""" + hypers = copy.deepcopy(get_default_hypers("pet")["model"]) + hypers["d_pet"] = 8 + hypers["d_head"] = 8 + hypers["d_node"] = 8 + hypers["d_feedforward"] = 8 + hypers["num_heads"] = 1 + hypers["num_attention_layers"] = 1 + hypers["num_gnn_layers"] = 1 + hypers["system_conditioning"] = system_conditioning + hypers.update(kwargs) + return hypers + + +def _dataset_info(): + return DatasetInfo( + length_unit="Angstrom", + atomic_types=[1, 6], + targets={ + "energy": get_energy_target_info( + "energy", {"quantity": "energy", "unit": "eV"} + ) + }, + ) + + +def _make_scalar_tmap(value: int, property_name: str = "value") -> TensorMap: + """Create a scalar TensorMap with a single float value (System requires float).""" + return TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.tensor([[float(value)]]), + samples=Labels("system", torch.tensor([[0]])), + components=[], + properties=Labels(property_name, torch.tensor([[0]])), + ) + ], + ) + + +def _make_system(model, charge=None, spin=None): + """Create a simple 2-atom system with optional charge/spin.""" + system = System( + types=torch.tensor([6, 1]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), + cell=torch.zeros(3, 3), + pbc=torch.tensor([False, False, False]), + ) + if charge is not None: + system.add_data("mtt::charge", _make_scalar_tmap(charge, "charge")) + if spin is not None: + system.add_data("mtt::spin", _make_scalar_tmap(spin, "spin")) + return get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + + +def test_conditioning_shapes(): + """SystemConditioningEmbedding produces [n_atoms, d_out].""" + d_out = 16 + module = SystemConditioningEmbedding(d_out=d_out, max_charge=5, max_spin=5) + + charge = torch.tensor([0, 2]) # 2 systems + spin = torch.tensor([1, 3]) # 2 systems + system_indices = torch.tensor([0, 0, 0, 1, 1]) # 5 atoms total + + out = module(charge, spin, system_indices) + assert out.shape == (5, d_out) + assert out.dtype == torch.float32 + + +def test_conditioning_changes_output(): + """Same structure with different charges should produce different predictions.""" + hypers = _small_hypers() + model = PET(hypers, _dataset_info()) + model.eval() + + system_neutral = _make_system(model, charge=0, spin=1) + system_charged = _make_system(model, charge=2, spin=1) + + outputs = {"energy": ModelOutput(per_atom=False)} + with torch.no_grad(): + result_neutral = model([system_neutral], outputs) + result_charged = model([system_charged], outputs) + + e_neutral = result_neutral["energy"].block().values + e_charged = result_charged["energy"].block().values + assert not torch.allclose(e_neutral, e_charged), ( + "Different charges should produce different energies" + ) + + +def test_conditioning_disabled_unchanged(): + """With system_conditioning=False, no conditioning module should exist.""" + hypers_off = _small_hypers(system_conditioning=False) + model = PET(hypers_off, _dataset_info()) + + assert model.system_conditioning is None + + # Model should still run fine + model.eval() + system = _make_system(model) + outputs = {"energy": ModelOutput(per_atom=False)} + with torch.no_grad(): + result = model([system], outputs) + assert "energy" in result + + +def test_conditioning_gradients_flow(): + """Gradients should flow through the conditioning embeddings.""" + module = SystemConditioningEmbedding(d_out=8, max_charge=5, max_spin=5) + + charge = torch.tensor([1]) + spin = torch.tensor([2]) + system_indices = torch.tensor([0, 0]) + + out = module(charge, spin, system_indices) + loss = out.sum() + loss.backward() + + assert module.charge_embedding.weight.grad is not None + assert module.spin_embedding.weight.grad is not None + assert module.project[0].weight.grad is not None + + +def test_conditioning_batch_independence(): + """Changing charge of one system in a batch should not affect others.""" + hypers = _small_hypers() + model = PET(hypers, _dataset_info()) + model.eval() + + system_a = _make_system(model, charge=0, spin=1) + system_b_v1 = _make_system(model, charge=1, spin=1) + system_b_v2 = _make_system(model, charge=3, spin=2) + + outputs = {"energy": ModelOutput(per_atom=False)} + with torch.no_grad(): + result_v1 = model([system_a, system_b_v1], outputs) + result_v2 = model([system_a, system_b_v2], outputs) + + # Energy of system_a should be the same in both batches + e_a_v1 = result_v1["energy"].block().values[0] + e_a_v2 = result_v2["energy"].block().values[0] + torch.testing.assert_close(e_a_v1, e_a_v2) + + # Energy of system_b should differ between batches + e_b_v1 = result_v1["energy"].block().values[1] + e_b_v2 = result_v2["energy"].block().values[1] + assert not torch.allclose(e_b_v1, e_b_v2) + + +def test_conditioning_default_values(): + """Systems without explicit charge/spin should use defaults (charge=0, spin=1).""" + hypers = _small_hypers() + model = PET(hypers, _dataset_info()) + model.eval() + + # System with no charge/spin data (should default to charge=0, spin=1) + system_default = _make_system(model) + # System with explicit charge=0, spin=1 + system_explicit = _make_system(model, charge=0, spin=1) + + outputs = {"energy": ModelOutput(per_atom=False)} + with torch.no_grad(): + result_default = model([system_default], outputs) + result_explicit = model([system_explicit], outputs) + + e_default = result_default["energy"].block().values + e_explicit = result_explicit["energy"].block().values + torch.testing.assert_close(e_default, e_explicit) + + +def test_conditioning_out_of_range(): + """Charges or spins outside the supported range raise ValueError.""" + module = SystemConditioningEmbedding(d_out=8, max_charge=3, max_spin=4) + system_indices = torch.tensor([0]) + + # charge too positive + with pytest.raises(ValueError, match=r"charge values must be in \[-3, 3\]"): + module(torch.tensor([5]), torch.tensor([1]), system_indices) + + # charge too negative + with pytest.raises(ValueError, match=r"charge values must be in \[-3, 3\]"): + module(torch.tensor([-4]), torch.tensor([1]), system_indices) + + # spin too high + with pytest.raises( + ValueError, match=r"spin multiplicity values must be in \[1, 4\]" + ): + module(torch.tensor([0]), torch.tensor([5]), system_indices) + + # spin too low (0 is invalid, minimum is 1) + with pytest.raises( + ValueError, match=r"spin multiplicity values must be in \[1, 4\]" + ): + module(torch.tensor([0]), torch.tensor([0]), system_indices) diff --git a/src/metatrain/utils/data/dataset.py b/src/metatrain/utils/data/dataset.py index 9c157f52fd..c090e531d9 100644 --- a/src/metatrain/utils/data/dataset.py +++ b/src/metatrain/utils/data/dataset.py @@ -979,9 +979,19 @@ class MemmapDataset(TorchDataset): :param path: Path to the directory containing the dataset. :param target_options: Dictionary containing the target configurations, in the format corresponding to metatrain yaml input files. + :param system_options: Optional dictionary with system-level options. Supported + keys are ``charge`` (with sub-key ``key`` specifying the ``.bin`` filename + stem) and ``spin`` (same format). These are loaded as per-system scalars and + attached to each ``System`` via ``add_data("mtt::charge", ...)`` / + ``add_data("mtt::spin", ...)``. """ - def __init__(self, path: Union[str, Path], target_options: Dict[str, Any]) -> None: + def __init__( + self, + path: Union[str, Path], + target_options: Dict[str, Any], + system_options: Optional[Dict[str, Any]] = None, + ) -> None: path = Path(path) self.target_config = target_options self.sample_class = namedtuple( @@ -999,6 +1009,21 @@ def __init__(self, path: Union[str, Path], target_options: Dict[str, Any]) -> No path / "momenta.bin", (self.na[-1], 3), "float32", mode="r" ) + # Optional per-system charge and spin arrays + self.charge_array: Optional[MemmapArray] = None + self.spin_array: Optional[MemmapArray] = None + if system_options is not None: + if "charge" in system_options: + charge_key = system_options["charge"]["key"] + self.charge_array = MemmapArray( + path / f"{charge_key}.bin", (self.ns,), "float32", mode="r" + ) + if "spin" in system_options: + spin_key = system_options["spin"]["key"] + self.spin_array = MemmapArray( + path / f"{spin_key}.bin", (self.ns,), "float32", mode="r" + ) + # Register arrays pointing to the targets self.target_arrays = {} for target_key, single_target_options in target_options.items(): @@ -1075,6 +1100,42 @@ def __getitem__(self, i: int) -> Any: pbc=torch.logical_not(torch.all(c == 0.0, dim=1)), ) + # Attach optional per-system charge and spin data + if self.charge_array is not None: + system.add_data( + "mtt::charge", + TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.tensor([[float(self.charge_array[i])]]), + samples=Labels( + "system", torch.tensor([[i]], dtype=torch.int32) + ), + components=[], + properties=Labels("charge", torch.tensor([[0]])), + ) + ], + ), + ) + if self.spin_array is not None: + system.add_data( + "mtt::spin", + TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.tensor([[float(self.spin_array[i])]]), + samples=Labels( + "system", torch.tensor([[i]], dtype=torch.int32) + ), + components=[], + properties=Labels("spin", torch.tensor([[0]])), + ) + ], + ), + ) + target_dict = {} for target_key, target_options in self.target_config.items(): target_array = self.target_arrays[target_key] diff --git a/src/metatrain/utils/data/get_dataset.py b/src/metatrain/utils/data/get_dataset.py index 16a964a1d5..6d94877628 100644 --- a/src/metatrain/utils/data/get_dataset.py +++ b/src/metatrain/utils/data/get_dataset.py @@ -38,7 +38,11 @@ def get_dataset( if "extra_data" in options: extra_data_info_dictionary = dataset.get_target_info(options["extra_data"]) elif Path(options["systems"]["read_from"]).is_dir(): # memmap dataset - dataset = MemmapDataset(options["systems"]["read_from"], options["targets"]) + dataset = MemmapDataset( + options["systems"]["read_from"], + options["targets"], + system_options=options["systems"], + ) target_info_dictionary = dataset.get_target_info() else: systems = read_systems( From d6bfd3ec20e742b5fa179c54dc36cad5240b859a Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Sun, 8 Feb 2026 00:26:42 +0100 Subject: [PATCH 2/9] add spin/charge hyperparameters to systemhypers --- src/metatrain/share/base_hypers.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/metatrain/share/base_hypers.py b/src/metatrain/share/base_hypers.py index 474ab8bcb6..e6abb3c302 100644 --- a/src/metatrain/share/base_hypers.py +++ b/src/metatrain/share/base_hypers.py @@ -31,6 +31,14 @@ class ArchitectureBaseHypers(TypedDict): """ +@with_config(ConfigDict(extra="forbid", strict=True)) +class SystemDataKeyHypers(TypedDict): + """Reference to a per-system scalar stored in a memmap ``.bin`` file.""" + + key: str + """Filename stem of the ``.bin`` file (e.g. ``q`` reads ``q.bin``).""" + + @with_config(ConfigDict(extra="forbid", strict=True)) class SystemsHypers(TypedDict): """Hyperparameters for the systems in the dataset.""" @@ -51,6 +59,12 @@ class SystemsHypers(TypedDict): The list of possible length units is available `here `_.""" + charge: NotRequired[SystemDataKeyHypers] + """Per-system total charge stored in a memmap ``.bin`` file. Only used + with memmap datasets and PET's ``system_conditioning`` feature.""" + spin: NotRequired[SystemDataKeyHypers] + """Per-system spin multiplicity (2S+1) stored in a memmap ``.bin`` file. + Only used with memmap datasets and PET's ``system_conditioning`` feature.""" @with_config(ConfigDict(extra="forbid", strict=True)) From 6e619e40cdd8c69525a8f593c6ed6718eb6ae1dd Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Sun, 8 Feb 2026 00:31:23 +0100 Subject: [PATCH 3/9] add spin/charge hyperparameters to systemhypers --- src/metatrain/utils/data/dataset.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/metatrain/utils/data/dataset.py b/src/metatrain/utils/data/dataset.py index c090e531d9..fc44af241e 100644 --- a/src/metatrain/utils/data/dataset.py +++ b/src/metatrain/utils/data/dataset.py @@ -1108,7 +1108,10 @@ def __getitem__(self, i: int) -> Any: keys=Labels.single(), blocks=[ TensorBlock( - values=torch.tensor([[float(self.charge_array[i])]]), + values=torch.tensor( + [[float(self.charge_array[i])]], + dtype=torch.float64, + ), samples=Labels( "system", torch.tensor([[i]], dtype=torch.int32) ), @@ -1125,7 +1128,10 @@ def __getitem__(self, i: int) -> Any: keys=Labels.single(), blocks=[ TensorBlock( - values=torch.tensor([[float(self.spin_array[i])]]), + values=torch.tensor( + [[float(self.spin_array[i])]], + dtype=torch.float64, + ), samples=Labels( "system", torch.tensor([[i]], dtype=torch.int32) ), From 831530a9037b57573f0dec4552f072ccc6b782ec Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Sun, 8 Feb 2026 17:35:03 +0100 Subject: [PATCH 4/9] small fixes --- src/metatrain/pet/model.py | 3 +- src/metatrain/pet/modules/conditioning.py | 48 +++++++++++++------- src/metatrain/pet/tests/test_conditioning.py | 43 +++++++++++++++--- 3 files changed, 69 insertions(+), 25 deletions(-) diff --git a/src/metatrain/pet/model.py b/src/metatrain/pet/model.py index 177dfd6700..49ff262473 100644 --- a/src/metatrain/pet/model.py +++ b/src/metatrain/pet/model.py @@ -146,7 +146,7 @@ def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None: if self.hypers.get("system_conditioning", False): self.system_conditioning: Optional[SystemConditioningEmbedding] = ( SystemConditioningEmbedding( - d_out=self.d_pet, + d_out=self.d_node, max_charge=self.hypers.get("max_charge", 10), max_spin=self.hypers.get("max_spin", 10), ) @@ -468,6 +468,7 @@ def forward( ) if "mtt::spin" in system.known_data(): spins[i] = system.get_data("mtt::spin").block().values.long() + self.system_conditioning.validate(charges, spins) featurizer_inputs["charge"] = charges featurizer_inputs["spin"] = spins featurizer_inputs["system_indices"] = system_indices diff --git a/src/metatrain/pet/modules/conditioning.py b/src/metatrain/pet/modules/conditioning.py index 3de8336db5..f305e40f84 100644 --- a/src/metatrain/pet/modules/conditioning.py +++ b/src/metatrain/pet/modules/conditioning.py @@ -11,7 +11,7 @@ class SystemConditioningEmbedding(torch.nn.Module): to the model's feature dimension. The resulting per-system embedding is broadcast to all atoms belonging to that system. - :param d_out: Output embedding dimension (should match d_pet). + :param d_out: Output embedding dimension (should match d_node). :param max_charge: Maximum absolute charge value. Supports charges in the range ``[-max_charge, +max_charge]``. :param max_spin: Maximum spin multiplicity (2S+1). Supports values in @@ -22,27 +22,25 @@ def __init__(self, d_out: int, max_charge: int = 10, max_spin: int = 10): super().__init__() self.max_charge = max_charge self.max_spin = max_spin - self.charge_embedding = torch.nn.Embedding(2 * max_charge + 1, d_out) - self.spin_embedding = torch.nn.Embedding(max_spin, d_out) + d_inner = d_out + self.charge_embedding = torch.nn.Embedding(2 * max_charge + 1, d_inner) + self.spin_embedding = torch.nn.Embedding(max_spin, d_inner) + gate = torch.nn.Linear(d_inner, d_out) + torch.nn.init.zeros_(gate.weight) + torch.nn.init.zeros_(gate.bias) self.project = torch.nn.Sequential( - torch.nn.Linear(2 * d_out, d_out), + torch.nn.Linear(2 * d_inner, d_inner), torch.nn.SiLU(), + gate, ) - def forward( - self, - charge: torch.Tensor, - spin: torch.Tensor, - system_indices: torch.Tensor, - ) -> torch.Tensor: - """Compute per-atom conditioning features from per-system charge and spin. + def validate(self, charge: torch.Tensor, spin: torch.Tensor) -> None: + """Check that charge and spin values are within the supported range. - :param charge: Per-system total charge, shape ``[n_systems]``, integer. - :param spin: Per-system spin multiplicity (2S+1), shape ``[n_systems]``, - integer >= 1. - :param system_indices: Maps each atom to its system index, - shape ``[n_atoms]``. - :return: Per-atom conditioning features, shape ``[n_atoms, d_out]``. + Call this outside of ``torch.compile`` regions to get descriptive errors. + + :param charge: Per-system total charge, shape ``[n_systems]``. + :param spin: Per-system spin multiplicity, shape ``[n_systems]``. """ if (charge < -self.max_charge).any() or (charge > self.max_charge).any(): raise ValueError( @@ -58,6 +56,22 @@ def forward( f"max={spin.max().item()}. Increase max_spin in " f"model hypers to support higher spin multiplicities." ) + + def forward( + self, + charge: torch.Tensor, + spin: torch.Tensor, + system_indices: torch.Tensor, + ) -> torch.Tensor: + """Compute per-atom conditioning features from per-system charge and spin. + + :param charge: Per-system total charge, shape ``[n_systems]``, integer. + :param spin: Per-system spin multiplicity (2S+1), shape ``[n_systems]``, + integer >= 1. + :param system_indices: Maps each atom to its system index, + shape ``[n_atoms]``. + :return: Per-atom conditioning features, shape ``[n_atoms, d_out]``. + """ c_emb = self.charge_embedding(charge + self.max_charge) # [n_systems, d_out] s_emb = self.spin_embedding(spin - 1) # [n_systems, d_out] system_emb = self.project(torch.cat([c_emb, s_emb], dim=-1)) # [n_systems, d] diff --git a/src/metatrain/pet/tests/test_conditioning.py b/src/metatrain/pet/tests/test_conditioning.py index 10d42b6e72..3a97d31732 100644 --- a/src/metatrain/pet/tests/test_conditioning.py +++ b/src/metatrain/pet/tests/test_conditioning.py @@ -86,11 +86,41 @@ def test_conditioning_shapes(): assert out.dtype == torch.float32 +def test_conditioning_different_d_node_d_pet(): + """Conditioning works when d_node != d_pet (the common case).""" + hypers = _small_hypers(d_pet=8, d_node=16) + model = PET(hypers, _dataset_info()) + model.eval() + + system = _make_system(model, charge=1, spin=2) + outputs = {"energy": ModelOutput(per_atom=False)} + with torch.no_grad(): + result = model([system], outputs) + assert "energy" in result + + +def _train_steps(model, n_steps=10): + """Do a few optimizer steps to break the zero-init of the conditioning gate.""" + model.train() + for _ in range(n_steps): + system = _make_system(model, charge=2, spin=3) + outputs = {"energy": ModelOutput(per_atom=False)} + result = model([system], outputs) + loss = result["energy"].block().values.sum() + loss.backward() + with torch.no_grad(): + for p in model.parameters(): + if p.grad is not None: + p -= 0.001 * p.grad + p.grad.zero_() + model.eval() + + def test_conditioning_changes_output(): """Same structure with different charges should produce different predictions.""" hypers = _small_hypers() model = PET(hypers, _dataset_info()) - model.eval() + _train_steps(model) system_neutral = _make_system(model, charge=0, spin=1) system_charged = _make_system(model, charge=2, spin=1) @@ -144,7 +174,7 @@ def test_conditioning_batch_independence(): """Changing charge of one system in a batch should not affect others.""" hypers = _small_hypers() model = PET(hypers, _dataset_info()) - model.eval() + _train_steps(model) system_a = _make_system(model, charge=0, spin=1) system_b_v1 = _make_system(model, charge=1, spin=1) @@ -190,24 +220,23 @@ def test_conditioning_default_values(): def test_conditioning_out_of_range(): """Charges or spins outside the supported range raise ValueError.""" module = SystemConditioningEmbedding(d_out=8, max_charge=3, max_spin=4) - system_indices = torch.tensor([0]) # charge too positive with pytest.raises(ValueError, match=r"charge values must be in \[-3, 3\]"): - module(torch.tensor([5]), torch.tensor([1]), system_indices) + module.validate(torch.tensor([5]), torch.tensor([1])) # charge too negative with pytest.raises(ValueError, match=r"charge values must be in \[-3, 3\]"): - module(torch.tensor([-4]), torch.tensor([1]), system_indices) + module.validate(torch.tensor([-4]), torch.tensor([1])) # spin too high with pytest.raises( ValueError, match=r"spin multiplicity values must be in \[1, 4\]" ): - module(torch.tensor([0]), torch.tensor([5]), system_indices) + module.validate(torch.tensor([0]), torch.tensor([5])) # spin too low (0 is invalid, minimum is 1) with pytest.raises( ValueError, match=r"spin multiplicity values must be in \[1, 4\]" ): - module(torch.tensor([0]), torch.tensor([0]), system_indices) + module.validate(torch.tensor([0]), torch.tensor([0])) From b0ca0d8a591ee7e1aeb4668417c2b38f0b0dd305 Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Mon, 9 Feb 2026 11:21:30 +0100 Subject: [PATCH 5/9] add new checkpoint --- pyproject.toml | 2 -- src/metatrain/pet/checkpoints.py | 15 +++++++++++++++ src/metatrain/pet/model.py | 2 +- src/metatrain/pet/tests/test_conditioning.py | 2 +- 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9b3d1f21ae..e31b2248bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -177,8 +177,6 @@ filterwarnings = [ "ignore:This process \\(pid=\\d+\\) is multi-threaded, use of fork\\(\\) may lead to deadlocks in the child.:DeprecationWarning", # MACE warning with newer versions of pytorch (because they use e3nn==0.4.4) "ignore:Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the`weights_only` argument was not explicitly passed:UserWarning", - # torch.jit.script deprecation in newer PyTorch versions - "ignore:`torch.jit.script` is deprecated:DeprecationWarning" ] addopts = ["-p", "mtt_plugin"] pythonpath = "src/metatrain/utils/testing" diff --git a/src/metatrain/pet/checkpoints.py b/src/metatrain/pet/checkpoints.py index 020b1e6590..ef1967e2d2 100644 --- a/src/metatrain/pet/checkpoints.py +++ b/src/metatrain/pet/checkpoints.py @@ -265,6 +265,21 @@ def model_update_v10_v11(checkpoint: dict) -> None: checkpoint["model_data"]["model_hypers"]["attention_temperature"] = 1.0 +def model_update_v11_v12(checkpoint: dict) -> None: + """ + Update a v11 checkpoint to v12. + + :param checkpoint: The checkpoint to update. + """ + # Adding system conditioning hyperparameters (disabled by default) + if "system_conditioning" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["system_conditioning"] = False + if "max_charge" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["max_charge"] = 10 + if "max_spin" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["max_spin"] = 10 + + ########################### # TRAINER ################# ########################### diff --git a/src/metatrain/pet/model.py b/src/metatrain/pet/model.py index 49ff262473..0ad34e4a31 100644 --- a/src/metatrain/pet/model.py +++ b/src/metatrain/pet/model.py @@ -49,7 +49,7 @@ class PET(ModelInterface[ModelHypers]): targets. """ - __checkpoint_version__ = 11 + __checkpoint_version__ = 12 __supported_devices__ = ["cuda", "cpu"] __supported_dtypes__ = [torch.float32, torch.float64] __default_metadata__ = ModelMetadata( diff --git a/src/metatrain/pet/tests/test_conditioning.py b/src/metatrain/pet/tests/test_conditioning.py index 3a97d31732..daa825ba45 100644 --- a/src/metatrain/pet/tests/test_conditioning.py +++ b/src/metatrain/pet/tests/test_conditioning.py @@ -111,7 +111,7 @@ def _train_steps(model, n_steps=10): with torch.no_grad(): for p in model.parameters(): if p.grad is not None: - p -= 0.001 * p.grad + p -= 0.01 * p.grad p.grad.zero_() model.eval() From f6ff254c36825c1c72f2749409515695c5d92d80 Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Mon, 9 Feb 2026 10:28:10 +0000 Subject: [PATCH 6/9] new checkpoint file --- .../checkpoints/model-v12_trainer-v12.ckpt.gz | Bin 0 -> 18256 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/metatrain/pet/tests/checkpoints/model-v12_trainer-v12.ckpt.gz diff --git a/src/metatrain/pet/tests/checkpoints/model-v12_trainer-v12.ckpt.gz b/src/metatrain/pet/tests/checkpoints/model-v12_trainer-v12.ckpt.gz new file mode 100644 index 0000000000000000000000000000000000000000..d2a289c9683efd57a3fa8751327535f8d6357ebb GIT binary patch literal 18256 zcmZ6ybzGFq7e0J1Eg;?9B_XXK-AH#xh_r+tp%P1qgwi6?NOz~OO9}$g(%s###O}U} z-{<$%ADr2pGbiVoIcN9)V>AwqMSY|b8q(F>*3N~`M^MPh%k!m^o1N$VXZ}~O-*^EB z8QA`0J+xPeFC^&Yv3H2fn&$Sr;&bVkFyd5Cn%F24;u_vkmgp)+`UnN0u=x)J0xP;m|67;~GICpo0Iw0j8E+^fN~Ez4O-bYCa-6HpQCs zEM$X@E5L<;y9CEF)qD^%K9g4=Ue*8fS5wH_@i|ZM>@81d_y*fZyYHNu{)aH>-;Yi( zVT7}>>vqbx(92HuGdPIJGuJ=$_!q)dZgoxin>o8b6q}!4_6{>}Rbb%kY&-597Z7hQ zjeGa-uHi9N8P56P*uR< zAKFIHEZr&YGBDfkf_d8`LpJi2-kXwyy6m!u@riTN&rwRRJJt)mbzX}IuYJ@+;qcnH z+FR&tLUZgHPv{0Frpv-hblQ(Ml7$bhew5Wrm~RSvH~KapT9LB$_MuqU@L$?weZswW zcsv8AFdy5!tv9bY-E{1nFz0}}Ff`%l)6j62u9xUqINNN8D(xwZIG>-LFl8O_o!_3$ z;NJ}k@oO{)US)YN%i!H1?m|3;##K&dIOYqTyoFp*+{b>j zig<12@6*rMOSzUKR=-6{Vi`gE>|I!QOyA0f;mcTtra{q=3GJV96+1IxBUnEo!QEU9 zd?NP2S(T zBSoIp5bMqosbQriJ!i*FXA-Oji6u+cd?g(55LON#praB}Ix7QbevvXdmU<5?TM1~7 z_o{vgZ>IlzR!D3dB)JXG%l(N#gVw?c^*LiR>=Zz+MNjv0w|s9JNR0He%=Y8z(_7+o zx6FxsBX@b67eU~5_(N7nmwIkI(eeR$IO7ixc*NYmNav6JM#l(8du|Oce^P=hg@%}X z>Ul>Fq;ff`pth|k68fUYOem*b{myPYBz*CAO)k^G`f>TNSx#bG4lAaa7<9Iv>QY1d zC+AG!s@jD@EcP`J3i_+tBN1+Yp6FC-N-6bK#{)+^fbLF`>`NM6p)dIyCn*htYx&{& zwYq=UK`VxfC*%6H>Ifs(X>v9E>qi+p1F{=6V&jAZ2w8VCvzaBeEjHQd**tUv|Jw#8 z+^Ztm^HH?=Y4eH^1sCTcwT6VBlRZZh&KvfC7Yc+V97A%op@{07;jx(k&(&rYd!NVn zKYm83`}Fm>MaubXQ)$y(#fI-c*1nU!Jk5XqeuJGo@JTNEj&Xv&Tgw3Bi|*kTW3M6m z2y9brF-ev}6-{AM#S~Ej1kFyVwFJWdOy#B`5q%F;@@1-NZ=ow#?PJ29t%1dord>Gn?aqk;uzO%{}5+}MLmr0A@ zfIcvE(Xw-6r}IHXo&J=+bk?RL`+b2OvN)nbDwWeEDx!k*j&jlekA!4#+?DiG;X>-A z-{gp&EZ`vFqld`aBHksL=nTC+Z^f5QCn|4LIK>d6BQT1;=1;m4DQpWpT@TgF`)*(y z!xp$+S{yQI7h6ruA@T2d2GTJTRVWagT2D-s?8uYl+|_iceN`QX)(dx-%$XuzI;u4b ze3H`$Wuj_`!|NoojfOF4WWtnO1brXB0c#Rs>XFz!ubru(D{PnbjP?gLJR?(nsXF%C zD2&ZC?nNKTqAwF0C4Nov%zjMzqV9mlua9FejmX^389m&+*fTRfS(?&FZ(h5pUd%F? zCl6Vstc?}p@>ixf&d$7N9-@_f%Pg6N^mwJ}E*EFBERaMPBB+S^Y zY;1J~L}tB669$qN?!LwJp^Cm8hR-@qYDy-n6p>+xZK7-*X-poH$h9wV-M^Xq9K`&L zKapGzk&MP&=`Td)? zJ~&fNTtLuOH8b0>b=<djhSBWl(z(QjBWB4+X6bAj zZEItJKS@Q({%PFu(rFZdv#*rxKkp~g=^_lPt%FdkVg z(;dLw*PcGqp71s@J{kxDB=L5vu_zpROJw*A|G*RpSB#(gGmg-=)6!$*5XQaOsG1%lsuiRvt`` zmlKOU6I(l@T@sn{py#%jY~NNp49Xe)PG$|9jw>faSW{R#Bj9Ov>~MEFHlJH_yg|!N zUVo6HmlrKX>O5B&7P|SG;aU^t+BFW&p9W|66DGpe{muy2t5zl~NCp?0^=^5Xd09a2 za?0hhG5xC$lkel>XZWYS4-y4yx8z@4#*UnKJPeXiKwtL%bZZ<3mrQGm8-MW`FWxdq zkB{w3ni!J_$uRs8nSk0dqCZ?iV}41NGxFrlFo5JrIF0SC%w#8FEOtWYTXr*iHKlS9 zxoE~dwhZI7@;I|1W7_lGAWI9&%I5D3?$HzaMW$--I5Yn^5u8#!wuhJwpLMv=lW0Qq zNG*%G=h-(@zMS>-F>8`bh>h{)B~q~sVbTZkF!RYqB8{IysJeVPdrD(7GpC ztNgy}lMiRv%Q!@I4?=!)E^db`AE<|?hsP*43lh7}gfv!ji4x@XhE5nxzV8Z5%O-lI zN)TnicPU2PrYYai&-4Tnqq6u<+K9VlczUvi(o?H$^NM9Y!u(UY--LRYKMk&qiB=AY zLst@Mag`5OI*kgs>s<4Q>Z7}jqDObVCImomy%B-6i&o`Q!aoD#yFPlsyDr^NiPN!l zp5pvY&VI(aPe=1qeM^{pmOu!v-XS_;SZ?dBF4AnI5)HoB`RB(N_SV#+i40qYk|iP! zX{$GZtULn1MtJtmPsW>-3I5l2q&d`q0`vN34qxtmkB!YN&i@WUlVOh5)oOZ0{{q!* zg%I-fkz}I3Z0gg)e`U%Rvms9<{*p>uf$9-cjCm)0Hon5+KUNmW*vmPFPYTRgHQ$)3 z&S|=#AH!%O7Q8$2)P8^|K82r%g>CDtcttZ( zpW45`JnSS_qkNp+5I4_SRG~EKC+AIHwD_yEQ=$2p)dOF&?bSXH36cFj6CAH&>38VwqU$9JV2yp0RytW4&=I3Yt61-%kZ|xanW$yOB(Y{g|cD z@uQo_#L+emp(%$q@XQXYXLeuCQ8+;gF^UUycd6q(P~pP~> zcpt$brQ!P1LkBzZnf{bj#?$v*#ke)mQ>#wHZ;^exbw=@&F!{UCr_60~KB1o5D+cQ< zS^J<<%R>i`CPmj7S6~L)&XO#}NwqD5?KWrq`D(&YY3m7_fI_%&H@TS+xKQ0PoToCf z)GO*z4fiRKY(;5o?yM%%)i9=s>qZLiG7JGsWV26z}J$ISdURC&Vval&G|!^&uI-Umu! zcQAgL7gz>~KE+B}vr&)T=0cc~&;NS1_R89y#s1rm9dQtu`aBGMU4J>9=&ghbCEijp zXU?U=>MUPLz^uHM+~?sQQJPYzXH11T$CRs`nv+>EaRNE8)a@wqK4YD`z3nLBby_Y& zc~jGBr23Nss*_v#i5trorJw_5jHh^*U?NHvMX;G%U|6n9Wc1&{d6UGfoOyxM8RPfp zT7lsSkjUruMbCnnAPO^Yb(8AjWU}|Obd&7*-gDiiG`U#xXg<59Rkw)qOIg}YjMm}j zmoupNawnG2&3PP7(j7j=?i=2X!A?SIyHMEV&ZegG_^I6AfFxRt1T-amQrigNLOLx^ zZspt&*Ibp4>xJa^PW)o@@I7VyY4u&Xm7CMAQS%Y;SUR~mKl90ao%4vv2ICsZHokrl z`IxJ@&a8LjqNdL^&-=ZTVvG*A@l(NL1vB{-qfxSkgugnl1dLgLd|@z7OsB#%#N|@A zd4csH``s0G;S8RHqw`b=nL6Un7oFlauowGQF~{(5(|W@9xZQ(4VMYm_3RX1fvlwwXCW^=W z<|D6G@Q7%7C$uO=#Ci8OnW=SzV#$C=JZXoIN1wV6vn!7K>()oLER2;Jh&vj1Mn>ra zo((IUq_dF;PI^pXD=WWEVQVc9qgWrqKv3A~pn>r`sjtRO7{1|a#V!1r+qYHaBGHe- zCudoyh1iMQoT`ls( zoEt}ekp;ylKV{t`oit2WwAoesf4{eK;E}Bw)exlx#KUk*V*aLB z1J9X*?1-SA(tztzSIO>^OlN*Em_TkylE^^>ll{+bB z*j$v2cF43CkzdIN@VBdXX3Se0!`bt-tAefczxD}|{0Z26L)I&H;SW+{w>Clugx`ZO zeO2xbi0=-RssorVSw0&k==}Y7C;$k`f%5*M<+lDS5A-BDNabYX|T_Sg+CgJM%n;+wrv72mQ5c(bSz6g4OEEow=ZR#nI;Y%@u#QABkrEB^ZTQ!Wbmek^?uglx?S1^+ z0L%|op^M@VL<>$JhDA^Ne4h|NVq>~$Luwgyc2sn39wY^76)}HCM$7+F3+zGM@W~}v zAn6IK-WFY8JHj8BZIPBdUqj{-dB#I@w?j;}S#Wz&JLOL}EV(+m@lrKL^G--LJ>|sJ zgu=}q&#--$(9?L?VL+yV-r+@5W>@m`vz0mjmPn!mm+@2>z0ok)Zq*IfB#=iFV!)~| z;&g`}{w-WYxthbqU31i2n9PMBgip3Azj#w5@KHFOg4vIGc4Bu|40l)E5_;8o>|%lf z$1xZ^xsB1-gC2?KHg7riRIX`qjMiXlU;QDcI|kgRv^$ZWro`@T81D8-H9=(PdPXQS z{%P`qTKim^M1F5`#VK)H zScf+x;>#W6*pzmS>~0i+F}De^17g+Q(tyL{7nGE9Z;C#hJ;L1jq5B@6%OUjCVcXm?piz*0BTBqw*>P=Pa2f2$&$X4B`+RLH zEy>9)B^;((&6N8@Q-@cvhed}Mt7nIUp8A#~k#}g{Mz#2X@TFaWYH#}sol046iirnB z*c#u3F69z=*Y;}?)b(K=!}z4qhpF5}f8-};vL&z$w!WA2Pg^~IZAM$%Sp|OZ(2Y{BtL^!kxrTz!6a1OCLkm*v-^) z?n-%ZH{v`urQ$i+U7xs?&|A|%hpp)F^*eP0y2h_1Cz7!f-&EV-c_nU`AqS#($ywi< z8`hdLNTWZ*4fD%hH`AX$jX9kXi66=C0(9)!)QoWvfkk&g{qJxC(c>6I;ON}2&p{vh z@?>X!3OCRuiUkwNO(yj+iU$)+f(y34A8~QPZucV%F8F>P$xbHyN>953jLF_Eb}xs8 z8~FNOj-Cv@NM?wO;3Y$J+>a_;cZu93WOu_lFiL#`ZkSchv3tRHy<><%^V~W0gA73mtUXYeP?+_0XdPthj4T~9xCTn`F z#PZ$=$ElnAd;$yq&hlL=Wd;~P0=QNRX5#)`j(>BeMtV4&|Cm5A>%*hpYH=bP zfug4?qB%n=?ben#;!yF|!j}1gj_T#S4-cpa<2LUKPe$5>1M=!>I?ajNA3Jxzou`8^ z-)a%~J$uDJIK@71$d7cxs2FH^{pIWNEyb7Nb|-IB?N8F@+t9S5;j0>I@HhoI> z->Y3fA~cPxL$i2TyCukmbtN$f()p_~O1!Q=%T^vK5?9f}b@5gzS*7I^Er(JK)<&N9 zEX!+6W@Wf?{FA^^nz5b%D1|alQp5wMNf(*jV-a})2un3*g;($-V4{a1*}=Z*A5M4? zUi}=HLGU@olfO(un>h zm$OrQutO`x%ximuosC!t)-mXX1hR0YUt>W9uutWTI+=|viFtX#uM$2{w7tO1$CBe! zJKlT{sAz?iMX)Jz@dim1@{kti=272>;iW}?pk&Lqw}OG(H$PV`8pOL(Jxfi_KY}dYk{o}K14N+YrTM0X1>EW`xyp!TPs@4%v&?W=pgO9|=wyq-m(~08#F7-$6ffP69HDXi5kuHCV)7^&quvr{ zHxPdp6yyp_hvJ4Ze+zd}oK_&7?n2uq4wJ&I`UNyRpyM!%A5xEC=2d85y7VaS;RNc) zWwlF8>JnT{;Kiy2+IYHBp}JCdRR_Raj?JiYro!|e9G4WpkKv^N;35va4C-0usWBeP z&DdAikJgL61)Y-wSsGPuGgY#E)L*FNh@wzy`xr3{irM!D7%-M;dBif-H8B`S(3j+w z=WzV7gCamMeiO}=!>Y-$mOekni-6)uLcohmma)XdJ2Ccg&8x_!kmaTk5aTO@EMQ+h zC?k8W?u~dm&MhyRBs)}EW!fny08}N=VZ-B_Rbn%0!)%q)Q<+Zz0gW^QH(%UFJL8V3 zZ%j{YVd!FrSgILq;Ep;7QZ&8pMoA1yqk>7I4Kp7T0*D9DGi%Y$tsVNc&a7b8n~;fJ zC7_4>Bpb%oVkoJmifb&+zOs_^4fT7Hyu3FC{c6v6F|CW$ywV5~K31LMlzjzCKUYZr zrFc~z)PDAO9sc1|>&Zs%Az49$D zFAn1gwwb+TwPSejtlI$lf;n3%8#yY{U{lO*4`2EOJq0!~qU%dPJs-+VZX%vDcdjN3 zxhUY_z4ZpS_?@a?pEU85o0H(4d|ME-ihNIB8VD znk3eV#A2A%Q@|h|#hy`ou`|U&cf|CI;+|)~26~@%qoZgdTA4JTnus6;M>uCMOH3J6 z&^X9MzR{^ZP7mX2Wbg(~nJg2YEE*IrDK}OM;V~-5FIGIqFP=;VzGG(3puXD@7-iA! z1icO?DC#V#K%c$<(h64WQ+L ztaz!3R~-2}u7!j!hcol=BDbz0H|%M8z`WRIRP9i4+JbnR2h>3G(ih$3A@mRw{jj9Z zOf;I`Q^4#6XZ<_Ipugzbf#lI{r+{)jGs$GAyaBCXNd$2D7|M`iUYFYc1ILdSZH^5p z#_ECxWx$>$wXS|iK7&H0J+pRhkE>p^HfY-kkWy%LeAQR`?!#Orn^{~7I%oIRGnf0RV+?NuRE0wALxJ*`2EM z302S+^zD00o<2~%D2AQ;pwS}PlrO0PZP4*s6dXNif~`IUTd^GDqe0M?r+_=GGO-Gl zr4p0+w0>X53%}yYTHppGD>GI;%eT&1hInhMcc-)+WF=U#1YAl&|BSlMQlDaXB4VI_ z@Twq0m*P-4WM46UkJ!ajjjZstPI-!r%+yX&~q!XqqdvakehX_8y$Qj`fif&ia5T zy3SkjQmrvG*4Nx$%E*=f(y+E-+~0miJtfg~3f6vtRp@AvELEg{BzhH|iK~&45bDxV zGp?J}bgq{S`T!JhE|hd>)Kq<_v>jE7re>|Rw#5g^fx6Lh3Eu86vJq6RoG98O5suu{<7y{X zVJ63?QXu2*T9Onfa7QgZ`2?jsrvJk+JU0Pn+ zr|E3ux^Ayn8E4jpDcgv)Ek<7I<0)wbVI2cK3MXhH{n3JFT?#aq^%R88f36ZGo|Zy$ z;p;raoR$b$z@`u=X4#wMmFS0m3Xpa7Ko4TT7*W0lB#{x=WF8i4y(SYCFNFow1m??Z z#x)z5jy(jAA7HpCEhhr=SkOYKa~jWSU6f%;>a);ohkT<(bVux<%pd07rhQsv?T^Ml z!+7+^Xv+1p)_y=kiraWkTmX)rAljS&)KY$W0Pw@Kt_Q z%f!ffLs~qk4e~;hK--C}mss{C zg8~SvYOSjS$gkywa?|cUO=sm%fU!mdfSY8uQCAMg%ZsI+&tL*tBw@_HHjs(EY~_{a zhIdUHOxX>M#))~}@z9IADwrQrY$X5KNYWs_o2a4P3eoMZfD#eJy6nD`9^2< z9g{XATp!tkBpvLDj-p+(G6v{w;fU+3Dtee`43y?U6@Hi~mI-N)sAjMo{v8WV;Cr=V zUmEVpGMqSj?mMCAP444OnnfwLr_XVm`r*H-j*S^Qet&!cj$69z1xqi_iNj>v325J3 zUMVn&na*dlifI#5M8lr$vis*C6znTugZ}zQHi7UjoONy;#x%>u!Es{#aVwRw5BN$P z@fc6SaeqfHr5LNu<{w8tg|R<)-yUVnz5@9ZUKe5G{jymP_jhb{molURZ7RAK{AukjwQfP zV}3zVZ@X0gnNy~hHlMbuHodlmwv={&HoCTgHmml>!s^1v!s5cz!sf#8!u-PI!ul-| zZCY(}ZAooEZESG9KwTAEobmQ>C2u@G?+$Lj*uxt(@~?ducQ3*~WArRL%mfpZHxzVM zuvNsOQkVBfh!v)6@#lvs=#EsYmiz_d>j^d2N8s$?R! z_QTwE5YG+%s*8r0p|Es?A)Ca-}9X*GHj`BQl9fIVq~X*>g0QJqvt~=-9IaOrx8sa zP)(wW{evn=L*ABt3mkLu;BpFUU)F!cetMSdCG$`l4+Won4MFvl8eawW2psqjo2Gbx{x#=~)56#+KKP+5CyTlaLL4OKR zcRoSYdh~1uL*CiG%cSG+4SMceZ-LbPJ86!EyE5x5kCF6vhAutX4Fx?%Vz z2DL?0U4bq^&@Tlzcuvv=wc8wwIX$oRqJM$T1%sq>Q2Vm3Ulu6;j5)Kr7m3d*i3uGB zz5DK#SRE%o1k0y-hZStaeEb>+M2EWH$FWpwVn=l6njsUV5LFb84hL5hPOOYf` z(*wyRmc?UseS9(=q z@qhjN-_JsBXkOk&#_Ym^sK0G_`5&j+@(Mq)mSRF@fwDyP#vL|qViRrrvREh18j280Pjog!?=tm?(CqOCI2Cl$g3hRSyv2yz{ zhL`j#RbSA`NINCa%kZE=`r+Xs7O2*PFOSxy^*P@qprZ}gjq;gMk3AB7_UfCy()(x6 zou9=1vbHq>W&uZ3#bzhKkm!-6L|eewe+ZQZ_S1NH^L7gDCzWr0OoH@E(Lr9^L0HA8 zCTyVUW?2;*pIE}0OIZ}(It6_nmpqL?i0)-3j!PIo8(F0bw4zG>y$7!*cqP=tY<{Wp zeL)|22Uz6Wj4E+hU!)t1cE9s>erJ>ZDP9w8*%r`3?JCC<)5s$%Ev0P2;AUVD-8!(O zDz3NEX40Mm^HU~Ay%|n63qU?zKFez>VI&agC0jfKvZFG53grQM$6RNr&WqsP5VEgrc~|9OOkCxv3K)SH&{x2-OjH=xCPO9B{vk#($!1cZ>_p&3qo*V-y>mB9g`;O&9@{@y zH{d~)FmM~?HvW>sigiB?XWtxjN#EIyJ#88UkGt2P+mqWRA1J>+tQQRdwv7R}a^tpj z-_bj+jOXNhSl8&zO+3MtRVpQaiCiF zCr%ix^;2*v&VCfs4cI4uGTh5&q8^3z3E|t8dZqS&DbPJqF2UCwHRJ|Y?Y zhbP(~=40`Nf<)#v;=i0|dyk+u%w;lAEWpM3J8DDpP+gc1PIZbJ9|6_~5%AEk+L5)Z z>SklL7lK9$exkeyTE+ps7*soP{vdjnPxXW*_!ZH&htO((E`@_Pj~dkn%T>sqV-!!) z0TvIlBT??yk{D&t?F4blW4{R1MO7wpYiU#qmJnU_6NlF-=$hCFsB!`xX^>KkkrkVI zCE9drrI_)8bhv|lZ&O&<>wTGg)ctwbW{ot47l1MWk9Nc?D1^|1g?pf;IO(6EPl0qi z=<9LR9cjp8WP$3MN2i{b|3M#l2z*DKP<6NNT|5^0DxK{hC8|lAAU@DEs@93s;J+MT ztxCg~9>OsPRha`mukS+_6OBf7%{E7IW-lIOOM<*oP*^k#pEQo9MY_4%iBio-0Wa{N zeugd*{Sp|9oLChq%WxcjAcz)#-KWm}$9F@@_c)`3&&#FJvQo z%r|_+lZ~Krf*^uY)TQWKeG{*W{gcxRdP_Bl4M^fv5u!rqQ^Ap>2&SjBW@Ozb4GO@m zBE9z@5mm9*g}XBo%DZc&brMf@Gg%kPQ~9eZD!fGON_ zCzNUvsHq|7%BLY%2Il_AF3?jom{0Y$o~(=ItBpyE9{!$w6g!??ltQVFb3U@GREMWl z*-uZ;kKc4=Co*ky6^4wxJQi?T&AJ&t)-AieVL6)HNjB8cQV)RP=H`SZta6-!d7pXQ z&aB-vuAVNjSv{=wPS=#j4{JaByZ4CuyOGhZ03HFoErFcqoyRl`z#6o@rW`u6nlaJ@ zs67IDKa9Qo)XN59}Bu z`ijK?OKt`0#(}}-X_LFT$DGM$k_$4DafU{MlF}7{^JqB+up>$3&x}uRc41A_9&hJ- z`meB0RZ|vSIqiH-e0D|~o-eVOr&LJIYONqlp%9ke#+%^H;aRsw z8qZ@JEar^rCt(p*GGv^e5xs_?>RqGatGR@2)Ws>yQ}|YD`HY_jNqGiK|EQ+4V^;8w zD!Lzz8fAe*SVT`|d(c03atL*EptX-hk$mQFI>Z?yz<}qQ-FBBvCJZVa?svt0v=AY0 zEn?ar4MdfGt?%Ub!NXc^5Idlts;%vM~V zX4O$Zm*EM>P;>_2KM2a=cVbo0TE1nbd#C>(7}+V5Y)O5va+TWC!|sV}g~Erxf@n(| z+{IhJrNNgjJGFCErdI&ucKl@`j^Fj9e}=R^px2jH_&&1aBFhyPYEW3{qi?TyG$90<}q zuAg7nN`7<|V>0CeRA!AdkGcfhh*{h)T77@mPB<*p4N3h7rnaqbfs?gd=7i3-XlIdf zw(B%ROf8)bt_a_p(|STq+HmjFb{4XI ze)v_i5v;I7ZmT+Hvwm~Z@pmfZ$_QI)eS5+oX1g#mu-?7)st|I#5om?HGJ6=(Z0)+6 zZ+t9(TW+PmYj4Lp&o-@K&O7us6d6V8O0#$FF1(7|RQXO=baTPfJ#uRJRZuL5vLjGZ8Lr<^2*yopGbZdUL%qE(^X>=|Y#duuwjy^*nP|cu8P+ zr?Y4%al2IBh;-+kJzUhSgFm$dTkPHZy4VqJ*@AS+4Ej|_Q$q@E$Jbho-JJ;zs=Si| z;@a1;svV^e^z{dRH!iR7UM;8DeRiMP$r{}QpV@<#=16EmjxwdNvPO&DCrgnFUn!xt zznQnN_h^wZIoC%Y?a{BE_$)|4h4*V0tYnh#r)&t<+5Sw-AbrZ=S>TebAO@?S4TVOb z#+P+h5L{PX>PDM3t$+w4xYC|X==D|0-7Sw-10?6#y>o77WOMIm=b2hzd(QQ+MJcH6 z{m$C~s*D>q)|=M3i_pug5$8^|=Y+<~~mODOo3)lje z9_~SqS(!^qEvsGLy3Zh10tP0?*k_PN>cBu5cg2j{aaQev)G0UpW`Tp6HhOb$dZEA) z2lw7CC61|_{WiVCNZUY&W9HEk&qP!G$3v3#FVRR2ta5sX@nuOVOI6MMPgpSU@GW~i z?4;F_>RU_Z)!C~xqnzRV;ahR4z(pC!n=QD)$$Mm^V3z9&<0B1?vOwj%GXL*`7KZh) zHp-<-zRO=iOg&1#>z{ zeNy)-Y~8-tAgovnOip75krZ8L$P4f@?;pdWu0azTu~r{ab5pl|98jzbM4k_xncd}| zogl;;A8{B3y6_WYWmGh0FcQ8r-I;HnKc1JCzG-fg`g34hOTV-Zk?Pm&n$N6AFK?1; zuzlGegbY2q`Y5o3(A>In9f#lDSq;nRuT8=4tgH~tIbPg`Tv!pA7y37+!v+C`3<0j! z4zuvAp%Hji!=13<(#V+rYUX?U!v8*$wC8MD&7KX{tqoV)-t1-d%gmKm56H~9Ru9R{ z9UvvI9a`6xhGph5s!_z^p0j}BNDkf;UeuiAx9QC+`oGA}cePy@w#@&!fg{*lN4#LR^i zVg+({2EN`xo(?0I*0U`BB2!L>`ORI)U&vF^?&d1W-56Cq*Ccb6^s zGjc6C|BP|6@JMWLRD2QPrf}un{`3z$_EPitNzdC>*<0;fL~*0njl(%?slN2d>)p*^ z^FjCBq0~w0i_$ws7)nb!)r3|?C zAjmQx8U!zPO%r~6wBY?^vUHeg1j#wO%xe3v#ASus3Z*VSe$w`us{!nHaB%&WW!@rDq5++$Q(_!au$ z;`1OF-oS@r0dBA~DmGrCUc2D$=uuW&ODGwnp)I->c%;!kcGTl7Z9ynawMCI zCczV@^qC?VD19wqcIlS+DzNs1U|^4Sz?wnAV9)D|RsceY;rlr<_|hTf1ya+$Si0`I zH)U$E(0Hlu?rQhz2MtGfm|ePMz-yOZGOuhAwC&mDeq)1MFX4;1vjT0|XLr9KO$t-j z?n@Cz-w55FVN?bmR9EkpSq7-=|AEBBUmw9WD0;!GnG#uhlj1^1c>K4on&;pf@i}SO{)EWXd%Vgt{`g-Zgqj;{#bR$*N+&t%m1wv zLz$w*ZgAT|9MSpIcuvJ7CXWyzYV_N@EWu~yuS6}tVBdqLOowLV95!RkOcSzeZd-s> zT&1%8UuJ{ z#698Sd4`P^UoW)fZjoa@{s^@p79`-1G@%KjQBR>d&8)XI9 z=-Ja!$bCMtESeBIHb)S{%D|W-@I5+nmdSkkFPF`24DQ;o77|x;a1c&vnYu zV2Kl=%~SCSFm)RXLc92t*pn!Q6*!7KYU&;RY&OIyV21vid$hE{qencaI_xXQfcN&-VXuB1L( zyz)oxJ@!QUp>C$fh=Brzogeo~#7p2WzP)X5|2C#m+5V{az#Rpgk(i$0E_LKXzuxT$ zUmDZ_ntMY~0cLQ1#%5gE&faivB!WVm;4Np5Wp|ayYjXI|V8GyUF_RARfMU z^E$in7J^t6=`W$~;M-QzdB)o+>_jZ22UJc?mAcy^{6sF$>cH@NZ>y-1yRO{}4wW3JD=a`Qbq3d+BP}SmpVyq|H16b>?Nhn^+^L7mj!O|C!nl?rmO9 z^*vBI-^{W=dEA2XIAiu#*B$BplIPEV4i{_`S|ge_7wxD1cV^~Lhu6FEDjWX1SK2q_ zeu-7@is&2o9K>_-4Jzz*q&MomIX^%;lR4yfcMOHL)%E7MBPlJG@)eL-j$jn3+Tm6r zNPGyPVl6n}7P5Ww^;@0HD@SBR1Ki5%9(DI4e0mWn^zjnxIEPf4LhYi{{pj1<1h)`e zK=x40B47F=DPKh%zy)4acYC2eqU2BDRxv#^YnP2Z_q^&~_tTLJHx9ich*+>a>WfFu z{gej^>eO#FuE2sQ{ZvcHp2yW~NIU!k%D@A zJ^z0>F>QldQv1*I+JBx$ufQvZsLxO_4UvzdC?D#Q@i{{d-kZQ9>4IC~Lnzmm+7aS) z$hE8)WY4tgeP*Unj}YXSZ%0U*{XgFhQM7`p0U}v|ibLU6CI&T4huR_0dlA&)CuR2{ zPY`f8BnIU$avvekyYgx0p2o=Cw%WA(-}&gJMuF@<4gY1zqZf5Wjvxrv{vX}xU#J=( z!v8O0>D3m*9`#=T|26eILB9SkFaPC@-#BJ&iT*z9|0WJm8%6WquK%}GvUe}* z;s0mXdg<5hD(C)X>wjXVL#VB-r@8s(W!eAW#D0rNp$;VB;t>!&g$w-kf?aNUkHe01 zLfzT*jomo|R+BI<(F66?QIGdEJkR)d(Wj!)V4>E1`@+xFbG-b$&VrIu(OoVt%VH*n z0yZ0_U5wJY{+O)LwCS%77NowQOYLl&?YJs^;Bon)re>T}yr<^`>7MuL#a{a^`qy8`0uwfGH;GKSn9faX$-e;eYE7BzOD$m*mg4wET40s&rVvb=!+Q z%}0+eQZP4)!r@w?TacgprkBa@Jv<`vi{@wKWc*)er@yJ#KND~9XzNEpg2UcS;wqJ% z3k{TN-#=rIDaX`Gk?&h;T9ho|yH^xaq>BUWjsHDJ&v&VPc}nLtCL{PAJz^{a*z;QyECYrFFGHekEPn55y}!s4 z1YZz3)Y$GO?ef3BaHXHJ$~JayP&ZxRtF-j`;5+c<_+dU=tJqB!zxDpDIkM+aRxn{4 z(^_AgVo4kr$X^aY85BI`(W-OOm~!_M&u_O`ufJ>Cjd*j`f9`ZauHepM{+h)CGDD*t z&@*j*|Bn0sPp-R(>JUf#SOFzpLL&661aBy0l6;e+z-9qG$L3Ql2#*hN``6qT+AZ>AZhRt)6bYs*5OM5#e{z3@}QV-aGu@WBXkN#VI|F(-lV4*^T`E9vF ze&4m-X9F>wZC&-(K7D+;|4|W75|U^>=-= zNNxSy*3#B@$vgVHb3>CBZNJa6awZxLe6M+EA$2zz=TntKnfViq%3yd)g3%HM{Ni=5 zk9vR~YJSh@f_~HrV!xj~;d{^iZnQ*x&wiV9bGTh;WvsU=IrxvCH_@p0kDnEZxT&Av zR}XmoKXF$STg)$0Z;1hZvFB_Spn{sF7SP$1`~eBH&@MfSLM=aVQxEV(%|GJ4$k;@w z6~5HjHP1!y)k=p^TEADP=q=Hp#MR9L%?NadbSlHbmb;!plB)*fuGi}HV7W$4`pd+*hKtO`rHsEv~G;8!$5a@~V0 zaw&(ST;|aAez!>>=$dB8C*|#uH!x0!!h8QBD)6droZ22^4hlswCz9Q^Hxb+ysoQHZ zTBRrytopw1pG8{yPc-`ZW|q+sL%`kI7abHXvS1b_LIMAf02xp(Kr-aFKe~a%j{J|M zH}+zr>B?z1D;ZtqT!@Q%8He1u_)joz2%#g;LqrDoqDBEC&!!4pz2fd?d7_R!o!j^( z3|A*c{$wu3me^kvo$F%`%v(_ps>=hVZVOM3MTyChL5NiUVi^S4uthCxvqmpd@g=_e zve>E1BP&SKb2J%$(dj`Wd;M7X+=kASRtIGC@K@(Ak^9wFnyk@5DD8?Tt#&Bw^5{=Q z=-r`E%Ua|d z4dr*_>Q!_f3oHifUMHTstqLo-_0(h5j9d^ItM1)*f_w8(@mzt!AIly}7lK~UdjjL1lJV9An0rD8Rh?cq+tHf$2kU+k80?SPbKyLDVYl44H5p|ny*Zf z6h#1^!&W6KAF5=K3fcdY4{@gG<=1AJmIE`X|14>NO$GcF(7BXz-<7uS|CGgad?Kk0 z|N772g7>BZ|NL9L;NJ^Bas(baZ2kC`j;oA;4=aBE(n0&nj#vjo^XJb~2d78UNpEu< zZyU-ei|)-HLB7^^5Q?CD7yg$Ei{Toaz`Xm=f4Z9b@cxt12mi~(PY;v-JoH8UW`U(* zjlA)ZmczVvi?O1NGBBr{UJsQ*WwZN?nf>_#83RyxTl~%QUoI>SOfev%0E+X^Ki^Q4 zl8k>r_-Ngy?_=#*(_4`sPvTHVg?wJLe@KvIu_L~dGY>CV+4sN4l;TOIqA3RWf0_-r zP?|?PkNh+LX=#wY*Mfq`)kZ%HEE)^N(fH?`LppB=#OXd-ch^Lvf9Z>Dl{FtOLhb#K zT@B9v5~|MF>`{(mh67@i={EmrAP zXhhgFeS!{Cu+7`qD}kEK9S7EPJxI>7~!B)+}+l)Z^#QYM3YDW{p{ z0}Igw58v2Wwz3t3PEVS=l~MG;I=$m+-3q4jzTI`W7;WLl8nNp1rju2Qy=)1qyf>e` z1@eBG-*U#BtE#s^PWtxBD0dS_7BrB9o!6~3<9%>xLiS>&Epz6+K5}{AT36?Fy<8hI zPHerplo2Sll3{gQOth`kfn|!<-B>d!H^t6pD*(y{GWzlE+{R`CiWFeBU8{Kgn!&tp zdnKicazW|5SXrv!BFNY^Kmkce>PfghAv@a#m>r+oS^-LhD}6H=-WA$=df+a1X)-Hv``FU)PpNP9iGfPz(xUcNZ7|M&&zQJ^p&J-d;!uBYc{0~HlVLW< zK9Ij(dx1=vzkq2D(4@Oie=m6Y=8o;WZ_A>2cbaJ(SOg5hW=Qe^1)42X_kyP<%fxse zTmS_oNb}wd23`McD;c)8#mu&x_w8IDNVCJ~y_;KsPo;VNwSc zf}`%F^=+i^-hBM>y?-kf0`oY~)w7$SdVwa-HEHk#CjnrH>;&5N9hez`>Ex_0BDC)T zef$CJe^6-KK}stf_1mYESy$}w_5E{dd)cI4H)|hf*Gb;Z@GmOjUwBA$*}RCGa{7OF ue_nCn{_|gLb`M@Hb(s5f;s3t&^OGhym%jN}^E2S)f96vgdTJWF85sa+&&OQ= literal 0 HcmV?d00001 From d6e587140cafd37ebb313beac8d9c159dd373078 Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Mon, 9 Feb 2026 10:33:47 +0000 Subject: [PATCH 7/9] revert toml change --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e31b2248bb..074e0fefb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -176,7 +176,7 @@ filterwarnings = [ # Multi-threaded tests clash with multi-process data-loading "ignore:This process \\(pid=\\d+\\) is multi-threaded, use of fork\\(\\) may lead to deadlocks in the child.:DeprecationWarning", # MACE warning with newer versions of pytorch (because they use e3nn==0.4.4) - "ignore:Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the`weights_only` argument was not explicitly passed:UserWarning", + "ignore:Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the`weights_only` argument was not explicitly passed:UserWarning" ] addopts = ["-p", "mtt_plugin"] pythonpath = "src/metatrain/utils/testing" From 68acbe2427f43c8ac11b7b7e69fb29e252cf6f6d Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Mon, 9 Feb 2026 12:14:23 +0100 Subject: [PATCH 8/9] add checkpoints for classifier and llpr --- src/metatrain/experimental/classifier/model.py | 16 ++++++++++++++-- .../checkpoints/model-v2_trainer-v1.ckpt.gz | Bin 0 -> 29421 bytes src/metatrain/llpr/model.py | 2 +- .../checkpoints/model-v4_trainer-v5.ckpt.gz | Bin 0 -> 8562 bytes 4 files changed, 15 insertions(+), 3 deletions(-) create mode 100644 src/metatrain/experimental/classifier/tests/checkpoints/model-v2_trainer-v1.ckpt.gz create mode 100644 src/metatrain/llpr/tests/checkpoints/model-v4_trainer-v5.ckpt.gz diff --git a/src/metatrain/experimental/classifier/model.py b/src/metatrain/experimental/classifier/model.py index 247939aaa5..996dc95249 100644 --- a/src/metatrain/experimental/classifier/model.py +++ b/src/metatrain/experimental/classifier/model.py @@ -20,7 +20,7 @@ class Classifier(ModelInterface[ModelHypers]): - __checkpoint_version__ = 1 + __checkpoint_version__ = 2 # all torch devices and dtypes are supported, if they are supported by the wrapped # model; the check is performed in the trainer @@ -339,7 +339,19 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: @classmethod def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: - # Currently at version 1, no upgrades needed yet + if checkpoint["model_ckpt_version"] == 1: + # v1 -> v2: wrapped PET model added system conditioning hypers + hypers = checkpoint["wrapped_model_checkpoint"]["model_data"][ + "model_hypers" + ] + if "system_conditioning" not in hypers: + hypers["system_conditioning"] = False + if "max_charge" not in hypers: + hypers["max_charge"] = 10 + if "max_spin" not in hypers: + hypers["max_spin"] = 10 + checkpoint["model_ckpt_version"] = 2 + if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__: raise RuntimeError( f"Unable to upgrade the checkpoint: the checkpoint is using model " diff --git a/src/metatrain/experimental/classifier/tests/checkpoints/model-v2_trainer-v1.ckpt.gz b/src/metatrain/experimental/classifier/tests/checkpoints/model-v2_trainer-v1.ckpt.gz new file mode 100644 index 0000000000000000000000000000000000000000..ed2e5fea31ea011da22b14f72dcc4ab64afd75d0 GIT binary patch literal 29421 zcmV)4K+3-#iwFp!z=>%B|7~w%Wo#{WGGBCZVQFqY^rB^d z8hKXd>j;YQitzIZ_F-iD@gz+CfS(T|mqrBngvuHulpKo)50o-Nyrs+_fx!7@3X7D> z8v4n@7{5TBL4Hh7xO{-`7^$CJN7jgEG**3L>P39g@nWRDkwL-!UUF$@L>S{W&O%4l z#K?$57QtKk1W08~)s=OdAeRnO9n+CD^Q^*~n8})ZiuqT+@JN|qzYY?)ti?bV5&tGu z*3zX)tgKanpI=yTFi#H(3yshj!1zlUseg|kr8lkHyUE(fWo_ML?L1ACE$u}xUS6gu z5B-9DA|j*_rlc8(cc$*-rJql@%qULQ-q%w(L}igaq4FU41X+i`dY=5a6_GwP_S8~# zg@??9bg#cOg0uwhg{)&>Ez*3x!C`)5`6dkX375)Dqh+1Qh*kBTmzS5B%xt90+@*?7 zM5J$o>Z2&WH&uh#B3fn{=^N=;mp|;oI~yeDLsSmaP@fQ~%*wL?G3SSQP>3{C?h~vE zCTMh!l#y9`))Qt1Ka=EMwbdU6gbIiI@rRXlhlVjBKEXi~eMkuXNfvPHH>f{2leVZST@k}PvxfqPc9AN88~5{P5LVW9zDe9e5m@{E=Z zHBzTyAKzf9Y*+$mqsR!Umvp>Ou)l100+}@tAyOYEC^SGeqL;(_(d0p9^w@}SA3rId z8Y9AjBT1|ydpW%8wM}p0-_GQu#7{;k6CM^!8iI6Bglv?{ACct4hzSvV^j>~pq5eV2 z&y@f#m6IVpQT%*XW__8ri;imM@`obAgFdF;dNQ%;-SnX(J$77HAaMkj(s`Kt6j1HVsszXLH$@Kvh~ETauQ8m1#L7AyAop z`IHjoCJX=d>?@0tF}{KN@qrBz0u2%ZI|Pmn6cTPkAfI3Z`Ftw>F`q_`l#Ns8(?C9f zj#nm7WzS2)!~6neQT>U%uQWpLr8+X9KN&BoNgAPa!^@vPbI2wV=O-DFWEl}I^$U_l z@ZNisyj5K!Ojfmo`rOr*A0EQFE1EQpg=|Wop-XkSY%1S2F4bdY3T4wwQ#H+Wxon1; zEXK17Z&*4}OwCNm=rtAo6h^#hP-u9h+$%UJlpj;3fo9o%Ra*p^0`9#R(sRxA$X(gMlLM9Sq ztZacY?u9DcBDrj_n{0`vU4?P{)8oQ;qmb`LsiY*eRFoa+B>Y7eQ+0 z=SB%)&v%F}vb{V)OD@}|L})4dX1}V74yZ6$avAF;L(ldV#gIPX1c)*_n48J6Rg@e; ziIp8x269M6I4qYPag!Y-*{FO1KLAQ1IrcroTy|X5TDih#JE3|xsREsn%TBw=&UkjH zFo?u*_Ir$l?3@adr|y*V%BKq|#6`L6lAG+Zr*%09KdJIeqC8jdANXkwvMb-yEM-@D znzme)uL?>#PIk>jTN%`K74e2#cGFFE%hU4Li6pw)-}9_wcT~|8sH6K^`E*ysx+j<2 zcauG+D2u-k<_~Mc_V9a_wd|29wn8O~1oc?eVNX=Nr*hddH`()Yr)wD=iR{JqG#lAV z6|IQSVr8$Ck-b)t-pFNd-DK~|&9-E|s&VQSYD7tN@4x4DmVHq1iq*V-l)R5B-Y2>2 zvzzRTr$xmg`*99`{T^j2`^KN{MOEaCh@aC+91o6Ts_=i1W@WVacLdCct0*tTB>b}* z^9TRzY4(dAD4Uk4%9s4!g)U4rg@=D+F(X$k!@ zUYIitR9%OzEXdF`CJlrvob^r zrn!nJku!#BVigxj9MeL{lqlnFDNtJxYHLDmqog+fIh91&whYsfX{TcHgU+ag8OyX+ zN7+H(855of;dNBXO&Opy(^-hmwluyj>iFygm_30}1O`8k&!Qwgx(v>SaS(8hLVSEw zT?tV6uA6}DPLMqavggm^vrxt7R0i0Y=_LT2OXKUUj<1gZb0M(41lCUpYouJ$hK2YB zg_f*>m=KM=>t6QEmci5J9;U)L7#+(pTX55uQKcNtL{Y z|HUJbjV?noX99%C0!t$cQb#sMAdMxYU_uJ{No2xM3@yX5V8R4exaL%o36U`ZFM{yo zgctdf$kgL-Tp5}rGhP^nQKjQ>f;zH^0%;N3rZ zmds=l6xq`;RV|`5N8-FPKpSShaK1?t`iqZjfe_h30l0_&7Zc!;lG*(4K+<7JWpJID zr9y`#mv-1Pb%!k%P%8*3g`igcEH+^*t||kwWmXHYRL!)wMrh2n0&pDxrV-%!pT(vg zj2p_}x-c6BTzYA2o7Ayw7El=kwS}Ozmc&K|$ba?QwjW_uX3Ta0mZ=%rUqWm<1mI2r z+(m%9OJe(XAc<{H8Jsz@SHSHnjcva=wgUnxi=bG7!k@(^^cz>Mu{>m86v?x5%QArrvNP7 z!zu~&Y#Fu{b53CAX@+`U2=#)%zexC(2>-H@-@c-J68V)fbZh3SK+i9Y{F*xQ>jL=( zA>SnATP2Z`My_Dwx680?m^%WyKr`~cg~;y;{Ck9dpYR`)L|!32iTq(1dS~X5Krbwf z{INRnCj$8?AwMJJ=Oq_tKK{x*+Hq3Vy%yz}!#^@4WP+D(|4@;;9GE+f;LkX3gyOBV}F8&Y#y z{&`G^Ma>_U6HGf|W29+Q?_XETOnYM4LDRC~&(_t9F)3i8S)lRH7vfAuQp8lVNSli8 zA2XduVKdFb7M0pzXUs`W7Me9#R&E=cu_RSlX;#(Ymv*QbYf^)aW(}H7a zrJH6gCcn5*$#f@m^w6xstP)$0OixmelV&|_E4-h@^dhC5HA@@++O{Cmn^e)K5{*&X zu`Z;JzM6IX!ls-?|Mnvl_1CPZ(NA}|m;t1Ot7eHm|7%l)8Ay!XG>sel_YM;?h?u)q z%-ozAOw2thW^Ta@A?BW%=4E!EO4JLb@eU=WhG~|n`)`{q%y42jLesGRzwgR0BZ=uK zO;g=}*$rU4h@JP3XOptK28<7}^wqS~`;VOy#*f(gSHRYskrLa{6|l8n0*Gy(rfuyq zI}c0{u^RJZ>_6^LFk^{bu%=zDANLlR5MmUnX{7VrQkDrL7U4g3Uzx=&BO`{)kKOk# zt70aCn96^g0{^->z(f+`aTPH(XT}rbsEQa{FcXOJ#EKYOGLwk$IKe_4kz(}?NxA6@*{mHBGn?4W`7!(a#~PKHOU&o}c)jVwn14!~DnkkXb{_*J_$q6_!EFI`V1Sj|&IQMG~`~*lf_S ziHz_i0Ta2;j|v}VW63fmoh)M}mn^9QRf-*EQ|Ur#vw9(wAuOb}kcHG%vXI)QT>bo( zhKJd%u~rRKiFKGvHD65j>9mA>I>CpX#D`tPhuumaenY;)?9uc{GYE zet~{~(6b1gRnmV?&BLIk11c#G!)a;k^kfT7kRv#8kT`LOIB{6%#P5iEm?N4lsFXd- z(b9c_W9l|IF5q(s{sh6F)T+RJ-G9~kFsC%RDxnW^x|DlH%{?n{&k=4O;hxu0@B93M z^oO~iNmS{6n2V*vOKRd}fp~=wuM%Q@MTkGB2$*Y{M3oGPxvo`mFN^Qg8$$Qq6sWfd z^){j2(Nb^ZRah9r6lgM4q9EpPE#Z#VU3HZA1m1nZdq8*(mAqe-2{Df}k;m0LhMJ@J zu@L4H0sE9-pAqbHCHB`uL(B_Js7f%zye#d_B6Xav1mJ4|d_#b5mB30%h?sYpY?XY7 zd9S5$XXt}g$(^BMf&UNTe(bWB{nxaYVgM?=+98qEoHG$#edDZ)5S z7-z(lQT1p-OXWQ0G=s|%f;%sOFA(rW0=^{vWjW7fO>C)-=Spd;SJkoR3&?8(d7U6{ zh%2Gvxv2>)74Y2BjPm31HK)`JC*8jYDNNoO;s3kV~k{|9nX~=3CNR&I@29lr2Y8n!oMw0Sn zHI0d76G??-HBCu@W|Ci!)iftX3?-G4)wCdmTS|UkR?~{q)LQa?AggIZs%k5#ysV}j zsliB68CgwxQnrKS7iBfZq?n1MGP0VEq^zl=QnH#(q_CN!^0FFpQiFx$_hdDeq!ug5 z@5pMbNgXzl-<8#LCiU1#DkrPyLQ307enVDcPpY6Kzay&wQU@*hZCQ;2smM`Maam1Q zQlgurLb96f#JGo~g0h;P#N0_zL0L^NV(u)dpsc1hG4CV!HCc@dDb-g}uB@gXG3+m? zfUIT!F?E%cC#xAq?A#>f%W4J@OLs{{WHp0{t%sx{vYH{p)>BduS7>|=w zNLDkR7)MDeB&(S~j3-JeB&(T3j3-OVlhs5MyD5?i%4(((bA_bBvYKh6z;sDvWHmEL z*%(QAvYJ?87bmHptY#)LkC#+fR+B&q%#u__Rx_KFog=A`tY$7To+l|sRx_WNB}&Sb z)hr-}3ndkl)hr_BizWXqt64%mO_KafRPnaZu1G@S$V1N<|V#ol&Ri^*+Q!C=cpxTwyNuQbg3T4Y*P}dD>pN? z3ye&{_=_-hd>??hDuA732%VW-ssJ42%Hb;4{c{2#j}Xoi!Ug3TJJQtErJDAE zMW*TCCK4(Cx&G&mQ8p6UYyKFesq!bqe1u=qbL!Sv^EWAfYAT zerWUYj+IV`Fbxf#P(`FGicwpYye{fLxS#9LDxz_sNrvsn<9|GaPQz&a0h?VLK|w zs$F{%!;$=3eyAu{I$orycoQd|;0P&yoctCf|F@C3iD7%gXv1H8C=;*fd+%o~?G{qv zeSO}0)sIegGV3(XCpc0n1fozEQ2Z>AcGPIwb`|m7%*3#A0SJM7A4KlA6ZK1i(9;Ye z!Y3p=_!pz#yF>L0dB55xX1^NJ4^j9|xUi`tiUyidgfn5`Qbrym{pE3`y7KsceQ>GA zRTV8$@%Qs0M(v~Kq)eNpd@h+KKbBICM}@kJ>VH>?6Mho)yV3b$1iQ^5vt3mnYq5tR=SC>DONk}B&|NF|nDkV$GHmRmv9D4t>a(7Mn?`Jux zL75z7?$oWm=zsa=fBENs`R9N6=l_HJWAPvPr(}?nl>JuDql(&2OZ_TVzbd^O_G!m6 zXa2o&IahF9%6Mn#)XlnbS8(6fba#;ZyE_a{2;-DK$d6LTlZsx!4 zL$=LWhg~DQxj(-4rJ0hcKGYN0mc2URmrz!fa$-mDD2Jzg9LU)3v-EbqO7hJs%U@dqNc$JX ze9p6CerNqU+f9eWe(mqg11ib4{2ji@F;jZ$7#DO~H=Ewl!UN`-wdbC$ z+6TQAZMmnfj8N&@neE-08uE0*a(pu^>l>dGoBY(bj(zKs{L@XmxmP9QxBi{@!ymVI zc+{a!cD2sWxFTsHw@=(CJ1!{)dcT~U9nv^}jE#fM-MlLqzfF1ZfAt!goj<3i!{yo? zXtR17IB49=VGY%sewE$dp<~B8pyY>d>2kS}{Lba^&H4OO^T|{Px`x8Rihm!(zn3p_yf}JVZCmLq1XDuxMrULz2(#zdQ|6u^pvMfxrXap=_^+i zprz9;-3L(cL)>G{3y!6Fs~*Sr45w$tZ^4<7185vS zhr8-&!;Kn$3#Oe|22&pF#n3jM^tPyR^m%3m_vV=+r}JVxJ#^I@y#Dzcj{3NbJ7_V0 ztD3YMxCM^*_o5lxUOO$yEqy%qDrXrtc5g54`G8g2@`BNvwc&2sv}Qbv9D5sHX>H=Z zwFu?b?0<#tYI)H3eKgNbytqT^UQScf~Hm&6^(-N|h!I0vufEAe5k zTC~JxF*mSt94)II%jHg(gpP4Z@M*RkeevA_y7T@ty7`{@-1wqBAkk~W4R2DcEQ!mc1(1vSlIF}7O+nbIbv=2L8?9UBz`hb_~%;cI6s6zksYBtR!Z|0_r z+XySY?dhjlF|f2(6>gFX(00#(ZdkP%{o(FZoEVjj=Cd0@c;FOl8l8_)y#V^W-wOKO z307oGxNG4GoPDhmcOma4UYfRr%dzT1rv*&rsFl`q$jerAoX<46 zezn=$x{=dqpVJ*U#g3}#v1~^}Xng9=o{S$JgQbkwdxrwL~bLYr)kTtjBfg`T!T@XK+V{ zE~b}0T1tyQ2e+v3eedfYCzb)4SdOio-_m7M$P56wR-QYZhPbG9R&3nI;{zTQ~j#r^!#jpuHUCb?n15cDS_woOen=-U2KB)`bmwxXt^kx@mcd!QCZLB+eYwH-!F4maVeQd>ztXN<%d)~v9{~yK+&( z3#g@j$#AmqbJTt_AJ-?C(8m9)rwp1}QvF9p zbiduLVdY*+Tt6{{t6`aik~%XWzv*kxE6~CnO-^Aem)dmSI2W!Uyag>Aa2ypn{`N-? zx!|YRchvc}%kksME_9EP1zEn6deKjFBkAs0#ke!jj$5wXkKW$RnHF8_&e=wy5;c9 z%JTSYag$$a4;`&Jgn51K!uOTC{}_1A8# zmV|x$_P&OdjNj&W;_o_Awg1rX=(y9B1CN`Oi*!%jGN%Al%7wZ~9a+ev8V+U(eNmzG-oQT!-Cxa-(S_`IeRCd!1J0 zKheYR1(h5BmcNz%R^2?!eSY0IJEXl<_GG$Q_Lf^0vOjE%$+q7fm3_8mXU9LyYG%*r zvoO0?(E~0b?LFJ&u5|DVQtd_6usA%VlH=dzSI2+iSP6CzP8gukkbBBXMBE3@A^|Fw=Fb3m)AMVr z&+50XKmGJw$@m*)w5CZV`M-7j*=A5<691IQ105>K|E>IY{o^5$e^Q4({iBln-^zdg zf6~wFefqh11@tqi>H?|%C;Azy-`3A4`qm11(C=RMv=$j*KGnPDv~Jp<=U{J_9KDS} zV%xKBIo}5m-=gC%J1O*YR6gqLfp##nH-U(R@1n^{-56FyHd*W zo2nj|mD1ZN$3~*M&%f8(tov=f&4h~UZJY$X%>{Q`vGTM&tel1^ zCzzR*)TPtCq+`rYEiJ6g&CRTAm47WzfBrN5O>4`#@5Ey(_Kc|{-{xn0{cg3nr?*l$)7D~c z(EeC@&;FfgxabOakN%8tRTXrBb6H&!IXO=pUq{g>_-*R&YXqY9C}NYH15qz z1Nv&K&2;TI9q8f53gMDf8h5CD8a*x{2wz04=FUz?r`tZ%;f&1=vL@Ck^x2ePZfvxS zE2ugOJPz7mTZfg9+-g3jaBoHJJlB8=Khuso*&~HB`(#CDOq&6pdM9#up)+aS!ZT1~ zcr+Xibml}4`@vpW8m(wLh92=ehC6a%1?{-;BIRw|nw~y4hTbt{0Xp=~Y( zaQeGXQxB(ya;G*eK}E+XHfcyJOv)UBicwc7m!au&;y@?b{sT=LtXTXh z=GUkV(R1L#)y16Kqy)~zza13M{)5X1$<4ZwmV)~6?vP%86rDQyF1FehPfyB;qo;qK z&gq?+!Bzcp7RS}L=O*2eaK;H|*r$>W=n`qdeJaSu@Hg+VbLd7|U-AW;N#&f){aUo+ zft9Q%eLB6zX)s-IDutf+aRS#pbU6LUZa$qp--t`P!?#<(RNBKn%-(QOYx>ld7wB?* zF&CG(1&m`)!3{kx?&<6{^!yVGIQs8}^y;4L=@mEkLt=Xm?o-pYbn2Esa9pz=JVq2i z%c5!Yy1n%{QQ|Vp>u>-vj_85u>IAxmvn6MI*NJ`FbSl³?f3Zzqq{0VmKO3v^} zeGJTN#GT5Jz}PL#ut%RwocHM{bUCVpIsHyx%d;u8?z_e4)N44M*LV|fx{2JWF{vO9 zpG%wd@ZuV;8Nn^Ay^+4xB7u$zp*d%1KU$Gtich+ybNU_5P{$H)!uzAAu}1WBc=2)p zmp&($J~ed#{m`DG6+r>;@NxkxDjvtpUc4DSJhf2q8n_irCv&z*Qu;+&06py6axT&S z8a%zOL+2Q-rB6-lNryd5p%qi_;7;8zy5@yh^q>xFxKm%_xu+(xV1|z!ZML{QcWRVA zc~h9r(qU?tUeh6QvHa&tMq48l7*zoG+_rg=H>0ZMp)hIWME5 z?p5L1+8x8EVF##7@27EN7wUmQQUEvBwMa~P;ZW}q8CR^ql@QH;MTPoL?3guphenM54gFnr*|%>#?`pJ3J=dP z;gZJpfGdY~fnxDB>QfCzu5Efl?!9#(v|PW0GhXA49&cY|9d^;-T*8WBtSFQ6To*>? z%y`Z23mL+N8_VcZn`ziH<1jk^1()VhNr#BAY z#Dylzp{dCWIO8hIXk#~DcC5ov&ORT(`|T{u=(CJ#JLLlw9IXlN&tho%JLho^_MvBg z>;R$$w$%I8(_vi|Z*Iknx!gF#6&zH3HRoa4oV(&Ui84Ct&4n(h#l3frL`9#bcx)`i z>D$bpX0LO=;cyr-wv3~aBRX)Osy)Sw+(}s(pWQj*qn9bsnH*L>R6=_su7Mn8DnFO4 z*f^WX9KG0pUbTHRr=MYusWormsF*68_oC0}cqxUhF}geLWR8&LkV-R~mUA`SMnKez zVf1Xym-f9lle2mI8lFaU| z$Cw@VX7d(tfrA|B+KZRd853lXx-yz`GFuJy=RZKZ)-&l(Dc0P>hsj)x4}a1gqRo^` z=xxlZT8Fdz)`=V6ZWeuN;aJo*8Ap%vOr{lYO(1UU0>~)d&*q1Ypi4K;{uX4!`>7GSHTOqZ`6y92~;gJW0kqt3F_!epOnuTS{kisfRhTT9hSB#6E4dz6xd%?qn=W$hMGj8nk8|WSI2I(H_Xzk}(bmH5MblXl@ zkhN$z6b`IOKOH;<%wEQF-LLiL9Hr|~WR+&0w?ss1hl*)`|C#jD;#uJOr4Zvi6l{GwY54Iw!brdA0d*wy6c0t!>1;@T^0NnmO4gt#HD7_HDS-HJ0eqd;yo&cLiiTUT1IG zasVxI%cL^gu2ZQW_QRc9)98JDTX2n!yU_M;H^98zew^-(=P+Y;U#{rh6@0(5C3n~+ zkc$h7qu0((=00U~pqnpW#(5aG#iuvdQm*L?J*n0QR74+v*;5C>jmyq-#*npC(W-CI z-J(C|Xc594ub)gm+#AK|hXrDWP8+1|Z=jRn7vaO_)3~;^w_>ZhX7s2Y)#%q|bLoPd z*_f1TNM|fhaO)JHxCGOE^*QF6^5v4YTshc;fSg;+PTbXiW`#git;J#eO-!14XUG3=i(+AS29y=gHx`ewI7fDY# zei0jAIDx5|^>Ft53G`FD8T8W&ZeV=7CAnW1=bddLQvSaktN-f#OY6?Rd;fA_oz+~% znErNSKPPdcJqsb%G=j75{|23xhtgjgt)pwa=*uO`fb(e39&}Clx$Hd(GtDkzfu%QB zysIu9VY88*#E-QVUk&L+yV}yBrU~>ZrZF1-VL}xjn@GPuG=xrCkWS~UIgF`ws?q&D zL+SjT0bJWvGa)rQ8S}>sr#<>L`jkxMU~!M)z3b3>bTAj zdaE{F*s}_E!@UjNDkzE;FC9SJuiOucCmZREMh*wYc`P9JPR8`QzE7n7{{JtwzwPg~ z{~r4{K%XrJs-q{?T=5XX-%*h0AY~PCcOY?P3_csY0b;g&0k_0Vbk|*qbGCdzUFUtQ zwxSU4Pc+6$;f@poZ^v*db(J2aTP>Qc*{d@=%t(NJToEQUJA#U%N3tGz zKF1!j3Ng6qUC=$+m}L#4v2b-K483(6YaTleq8hi^OE2c5XS+k#dd(7upEH+rG#bM0 zll_76cH^G`kV1^ebHn181@-0cJGJWS}UoXX*8AZ`4A3WdIX+XL%`YTBp6j03`W;CP`9i8 zjm6!wDbE&rFfFwy7-a0CL^`(id$hMB*J%YB+9D(ln*e72#6fVP0ouA}vX^YcxTR($ zm{lE#7w9#xd3q*z8a_c$BWF9us)yMlv2#&RXC!!Eh+_M^S3zIn3cC!{*j9vwb&kwR(r?-Oq!|vFB^EuGVy9!!8 z9-w@Q7!5mL!@}u9Vb`WR=%krUq_De^5ID&YG=+>yGu>oi<=G7X=!Ne9=4<0;3O zrc~&GZICng4c4|ei%!%BXnkNAUXR;{IX!c;qVBiCK!Z7~NlXShqAkQZ-^A2AZK!`Oq4h^;tcHHF?eHn3~yRV=7Ak<~ub4es8^gT(rcz;nh)%A@Oc$W00cdFPH0%r!&0 z_Zdv@bc+q&yPtI&txvgv4_o-PKfGIK3Pr@?6W;9}N@FFP@Det))T$^uHUbwgH8e+TqzYyjcu64ZY_7kzVC z%zvEEGFM)q+2l*WCd)C^)EzTjfwDay0q-VVS&P-%!HimhnbDm=bav7HG>1HladE=d zeYGGe^)NN3%>gj%o&lFqXJg8`Vd&lB8tW0T7n@Kz_;&Jnv|_tKO3U5g*sdk@w)UA9ZnadLuv z52tQWs6PQr{gct^(-#ztx}0U!d<+}KuT9d<)Wjou4@2e#ZED=B1CZWz2gaQ53x?Wq zYF~8)reBDL9fhB8$A`1vU84=^lH5^{tUJHh&k7Lylo4|1JGS;)j z3pS_8S<1jZ7e;lMjX7~!uvx(>aE%Jbv@5M4{bmyT_IXX1KKw8&#%UNOkHR#INnqHc zz}{@M4_otg0X}<>4~k|hvQk&DENkqB+VxMeldnaAq8F9*_QiE@ob)U!`QbwL;nYMB z{WB=b!1pwp)7%*L-8%t=3E!x~t~=n$gY{5wQWq$R1>Cf01DT=j)XfvupvdhRwC-U8 z-g8c8naz9v6363sx)BQ-*3JZNzcdhyXv#W?veA3x9w@%!g_(P*LZj*npdP%$yvNlj zs=ylKcI<}RU)SQquod9Eb~Tv6To`ru9fT4~ zhSQ^P+kuY|*J}^x|HFZN;C85;xeCMY4M$xsXAF1jhwi`0%hnlBrfWgKdIPcjC*cICv z6?*Ym+SR&a(&g8f+IT&c`=&k^-B;jsos|%}{sEdCe~)ytRcQFe48s{eHs|h9HYXzp zYHuk(-P|N9=TXB)W^i|4FmpGBU+)P{e+!sLAz zbKo<)bw3DwY%?)e{0c?qSKAjqtBv0M^?~lQ6-|~5gl)}kVNv6m=)X7)Q*y?GXrf0} z@gMK;nbkrxcomFv+;iO5YzFmAW&??TXR*QAE-2dXbfBikDJ(8n!LR-7vUUXD16I}) z^^QkF_|)F0=c0{kcf7&m2A9Cb=^qf)G095axE50V&%=>{Q=otT3`pPo2D{BG!no)o z7~aMZnuOiLd1DS^afc{e_~t2!Qp~z&XDy<3oLhpsR&9hl?{4VYS{EI44uj6g8<6;V z414p~Ui1x7psjQ|wMRDz6c46=K?7f$BYur*s~4d6gO(_YJ7eehQHtqfPom=T5GvW? z6mCm8f_6>wP#bm7s_9)=+}0UfFEOlH&IL;Mt_z5s(01vQ&f>iND^d5!R#b>xvyKdT zg07toup_$IV?Fy%K>wKsmv$@$#hwTDbB?ybkM1Gw1^!&{WY2; zxWK-17vPEITuheeK+%j4*3f@Rmg3+G`(gRFAkQj+GW@Uxy-jS<(QqQ99BmGY&du$M zThCdX!zndN9~K3Y-_tpi&1lg@P5A z?CIninDc0*z0t)x>^R9INU8S-3frEdob&F3bN&qMF#R^f1>Jl_L z9x$qZBCdLV6un37%9?(CA*M~diMf0Z7&9ybokm`Ow@2qe)RsQD+%p5UpC7?7CQGq! zyBpf(ECZu??LaZ0C+a{Zw64Ay{paUl@#H&HiqTv&+E^Eu`}vS=$l%#^e6HQ=2pg;( z;UVdLNIWIOn(^boJMk1!y|!XbvJbx&w1&QlW6-406>wbeG|O<~47Q-hCMxsRbJja_ zw*4X5Yy8^w0*(mT4x^Ko;D+QA;Cj%Tat-~5vf5GzlYQr*V*MJnaO`es%TrTW*!mv~ z{>K`ei<2-_5l4-%*b2JhzN})=v@CB)Lv};zWzdH(BF-kHUB~MSmTjPBo zvRe%ERvgFoqmD!7(e7-M6c)M-y$v~&ys4;KgK^5)A}|b$MKh<3aACqbyi~0P-ZsC3 zxfxR25-=2eL-*ooKYl*c+zaE<<3Z7W=K!cZ=XjyJGM*o5mI@_QTg_x#%`{1I~-N0N#(fvQr8_L(VfjD#zU& z=NY~MquX~tw=PXNneo?&o_}SrcH{8+tOJmPXkWDOD%Gd!3YeVS9-f_Bjh>q?L1xdD zkW+mXzMp>(7I&=+ISWs)ul4p~=A%ul$JdYO-*h2t3)~Oh0rk;v&~vzy1eh~?JGCKT z2O64e&9eQ_AE<;8;5{@U%cH?G+*7mz+D2~z(cm=drhF{gHoZoP=+<;`SSU47b5JjS#-GMIGu+jV@nEUn^mEPnbmD_sP~J}xz}Eg4wYbxD zh&r?kVmj>wvu6qH`=X0jaEC#?U{gp;-U)#^&d8p>2*oCgD3Otto#D3C_UVszv9&k9 zgo47q?0wIqK%ct#824ZXn)Mt_O&oI?_8p9XylRIj$@9;ssM5$j_;WfATY4R%s?MQ4 z7Oz2_;e{}-&1`fRwE>InNg(;>7)m5FK(|Q(>$)KZ#^DDTVKEF(XOG0i0jEG0{isaG zzO26OAxIxJ9#iR)Sq4|PgSUY*q%T>Ix%K~m+MN#L&8o98I3*4G%C*r>IvsO5)TL~x zkCb!E)u^{_3noARfcbqSaH!*YG_2o*O0W8W9p$$XpLw^!+qp+@#Ppe%Gi++sz8lxE zTWkS_9?Ir(t0QF)=8HL<;Uuh$=m+|e?F(@ zXuX8D{JAcD@Ele#A)A`B&=1o?p2CqSnOLBGFDt${3r85%hPK_WgY&rOU}hLW_35C% zD<_K~@2UaS)@m!vDvrUll_9ul#wGNg`xFvyZKAr}*o%sXohbRT;b^$}w0+*I_At5L z6fEApgR)z@2DROeHGOqZ99#4D4el<%-8Hx`Zo!@4?h@Py5+Jw*3GObz-Q5;|LC)&XaI z5D60UP6KAmOb`6Kob?~O1v{lzioiCAToA*HfJ8u+X#;m7oRorM^wuiut1@r+%W%o| z-;U45!)UN2$#yL?k3=n~?TDnEk9K8(&^kbeS7?P!;(w_d8V)@!-S;li!bjB+t0tC-&S*L;@4O+hT@Z!jcba%u zDhsGrKYThSSS6SxE!<1sd`fGz{`5j=j5idm5B;3el(+KLEM;{cEzc6R1-SotT^~Op zb=@|1t`6r;M6Q6og$^&LGu(gcPw?^m6XcSDjKT(lhjGz{3TyMsPD6{9eeR7&Bw+9( zsXZ>&Yz%sFf5um%!@v>bw}Rv4ZI^FN5-5l*vf=+Ak{?JukIv1AqYWgC@SP zUp!@{rG~?Xp2tS#;w4O8b2pqV*{UiPq?gh_@6B3^b$SGz>`)?UhV9uE%|~5ZG&04_ zA4f}WeE#Mf1L#v`dlcdFj7g=u(M5q9zvx}O!hBwCuI`P>tlM8$ zZ~**@E6m=ZgS^FNUL~pXcx@B{fTh9F4U(t(Cz(ax868HeO}iHo?Aad3u7FNU7VyHq z#7ec|Rkh(kDK)r7@tBMo1B)oyDhtpkj+|we7P3cv@+P#*efwn{+$asojv?1j-(7BaIQb4EF2>K=QI{p$AI@|&%T>6zB0_lEfqOP9Ho#Cb+9Ef-MWD}Dr?h^ zblssDyLh1PN?``d-AKW#pUqKb*p7+fcA8dU<1nnC5WiQjJB{(`N)&pF<+BbIcA14$ z@A<8n{`wQ%K$g&CO>ft-AYr4WG9hue!LWa>ZPk+V&7E$Gw-wdYa|0F{_eV)V)uUjg zUzUB(wYO;5TtlCu)^nZf*W{Kyi4>S;mnVz=YG#KBuy|Q_GpeUz?*C=nb;6+|5KJ6- z=-E^Fs!{6T1#=s`_!n$6VEooq@>yi><0-s-8g@}) zzqP~R*vMR{R|miWb8nM3`tUzrKK*p@^n)sht$_mGqnfO56R)#75Nvj~;0m}+^mO3H zl_^9K*dO|tIajLvZo7NJ@Hqt!T6JEECS*qwe)k~!+>j9lno}0fKS1SiPe3F*2gn?Z z3l=zd02DG?;}tuf`IrJ%pt$BvVMrAt|GKaJo@Tk@hjBDiCV21~Rf_frw%6gqAr(L` za54<<(B#p`b%3wE0HgKQ(5goX#>zA;64|?KAk!D)>#MAXKwr+OzM1pY&OYb}-7}qC zG{Qh}W*aYTFh4dUz%KEOdKq81jR0};nk!gvYt%^u?O!8yVm4OxI}o>*nUZ6i1PEMR zN>=bU1VkIVREbdcVt06kVYvJ{gA{L@e@@4M$z|4hxDonaCca5y)2rQSBduYpHGN?vP7NjWN(5g ztcLs_?0a08mj$t0-`;GyL2Hg41FhF8zGwPe>niw-o&jRbp5F28+~J+h4={n})+r-j z?7O?A9Iix8CI!wagPMsZ;$60nDz1rmzBr+~SB?_M#+cSWzq!KT@$A=p_Z>E!c~Ls= z>Ne*)DiB=9(apfSq#_F)csvi>}dv?>Krh3#6xjcOisK=H&F% z-B<~Ws%D2|y)X;vWIX(nbq^B1B%Ay@JlJl0=Lsh`9PnxS1(^cSg9a96vBdTr?6uw_ zN+a&ZX;3aPecU-xxnGu6fQ+LymRpAGpU|IU`as(@O@!V7<90m8O>wGL?06}A5ax2a z@)``#OmEJD$jiUekE`|>J8y-1J*zSowz~CxN7q3(?+ka@T>6akQf2H=GiecGSC-h171pXxz1cb(~=nth8C?tmOl3k6&yY3z8a<16jQDU$YZx5g;^f(TGnScIO=$ zIXS1yP&roxA}AR_V7~vYrvMq^_~?-Ye+QBE9~ErOjCfL2Tx27Y-UEX*lN3dpx0j4= zEsZaS6nMOpj54*E@OCG=Rg>zUx|9!U!qIQe#*U;$yqLzBriB;DZr)pfc<6xPNPXUm z+hsT6NZu(6Wy3avGP)GeIL)=7OPz?5AQLZpF;Y0CuAxJomVx54169%j(jEwt8sD@n zy!52p;Qkvt#Zd|}?(gF995XZ31kRvIfEvDEx8ne^57CapTA6bt5R1zXfML)|$zwZU zdpgVS%jI{4cgY$Mr_6zU-3UTDCn|<+t41!RLO;#l*grGAyQlRTtI7pd@qc!_No!1d zBE2WFGW&bLeSUTT`?f)RJ-DZH{G(o_;Rw0Hiz~Y>BLJ48TVohCC`t^Nhv&B33blA? z0Ie8{{zLR#`g46f8soap$NYZ9xaHJGYnB}L9FuK0ob{Iqmorv{L*lQJIO&}roA9Sh z@02z-69%=4)4V-VOu>V9h6}%K)2X*W0D_HP;v#K|n6(3f+-b+T?CPuCT+=|9m^Jol z=<;GLwHyb+zlU1rAAS%@2KA^c$0|Ek3^f&n&6?m#=N725j1;`N1rfF-y)6|1-io-2 z)VNN)Ku)gtYG$sf)An}g1KJaabkIeIDiqQrhZ61}JP3vTIB%RDst|f3#w%;|i?ntH zE@1`CcQ@#t;MItTKmN>zV!~xxR#)3c(#ds>XhjJ5Wt8`LH0ain`x{5e+V5C+do}8o zqQY!lTp`SY4m3r1BtipUTG+XMJ^sPncW1=;b79xIcm(~mvzuU(Ql;94Yo*$v^ldUh zJcV|y?~O=q%>kCB%s#k=H5`%Ghl<>mNqx95p3@>fV3~rnef-BNnO4wCn407yOA|F% z1eLT4aq6}t-`^TDJw-;M&iK=sF?Mkc{MS6+8|=sm8Ht-ouS8%J<t=BeNM+x2@ED|KH>E^vGJpZuJ)l6P!evAsqi$2N5t| zFxBgh9cDjASB6afz(QLE$uNcp3<&#LGFLL7svh&yJ1s-2p{w;?3bZmUh5WTU=IYF>T=t@_WXpff0gDdf z*h|gkTzzR*m<8NDg$bm0$Veu#uOSM6aH=mels|cb)F-||m$%>W%fHX(_WhLL$V=Mm zn}RRP)SPa=#l>LAmwT~*b6VDwZF2gUf(WO|g}n|+?F1-bG#j#t6u8{^rvPb05F>+2 zy`4W^$p!xs;*V*Vw@WJaO24$m-|u*-i@5C>0R?SGz^%IO{G5acAq5}(U#N^09i z(Y)w@eveL|UV&s42n}1O9q0(vNsJ8LJTUhNz%0sQ+7CM# zPFzT*9$--}cp!zvnd6O{@C6m9uAoi_bSK!*&@ znwWT6*4q)aqAkmObv(8sOp>qrW;hOYsI;wR!?vT^3}3!-A=^@!CYjNvYrW)3{pgFw zM#S)=cFreai6CKH3yIsEv1r7|j^Pm>alGCA2aM{MSPMa@GSK53PFj3$v5Z<}s~g%+ z!7XuN)~FjgsrYVvxrL7kM435TTiBL(@)hx9X5H_6&ww8j_i#_I*_H3FC<)$i^e2&s z-cC8FHI-}A%y$R9YhyOA$PAXT5F)HSd#uh$mF2jR{MD#DbIuJyf9aag!(H^NbMT+p zp(WR#4dMu&n4^cJ7d1@e0%9qa$Z{lE*IfYu$J7G1{!euRMB%Fb%}M@V)NwksDe@8^ za!&&R$!1sL-)~c~-g~t1o^miY%1YB9Yzbp-%A}Y5%&?cS4F$E79hLV%3XWecA>)&_ z^3M;r$ov4jhf7`i{%Mq!!w!qLuYjc`rlS#t2O44GMofH5rx34MGMx1SWJcOAwgO)0 zB4-XG#W!e$#YdQJPJ+8zg_LxSWGJ2}A?L^bvxg!_JX&4rUf%3TDJ}8#-DK-dIg3jK zo28v~h?f+Y9aA>kC}z?sPru<+`Xng4bqnIz{1k+1wE8V)(G0y9i4@P2RDdQl)D=Qw zWW$HOg0iX7uBA%Bd}RYcD=Sy=b7j-s(PGwqyw;q->RYa6{~VxO9yukXvACoY4s zSN>^N*QP$Ud^Kl^dZ;5?{b>MF`6#E{O9W>{J1==tnKlf}dyhu(p#sN6QwNtnAevVu zf7z7j_p4g?`AK^MkJ3;EDInBdK7DVxBHTn6BvNz$7*qi+rExrA9oz@Xm!la?3+05E zs~AW8b3xfTCE@)#^UoBLY5IVNiZ^oY9sJsq0WqJgXqWES4Ewsm$^mI3 z0JnaKc^J0cm_j(D9*J>oF>mqWGir$j(nMQCgoCn|+nyLg7eI!lEY&2<2w9?kGv1z2 zGUOEM;GW?K6YN9vQizjxf~j$Ljvy{xH6rjOIN8y=viJtB?5w;LQzBu9llZT4rU@s1 zCjK&$b^ou7{kJYOXhUIgB2VPBj+UQR?e9OtsHu!pmTM59tvRugDa3Z+&?2r>QhA?=B_<7|?7amfn*9Y(9Cm*K4 zS3(%DZWj4wx6v1Q&p6YPv!lH%GX$LjU=?kLF#&hfv~znxB_3)SrK>8oJbdg_qDYLvYgWi++J_BDm)@Bf6-Xyf7r)aPkc4b{(li-0y1u(kCN zkmu1F<=kklKQA^jpW_S#zGy01%(^DeH6|~v(@8$8DH5iN>0eo*A^)Y98PYfMYB$Ku ziy~BM=t>!Y=B@8sJ#uFELU>7MYt0Hk*Vw=dnQTJ}cgHlhb_-3-;ujV2qa1Jo)t3@= zQ!sjm1_%c#=ZEUzn5D`D-+8XD%zrW=i@!sXFJEk z+D8nxPo=5~6#!rc4PAO`@LN=jKviE%-q3tlQeu4(X|o~X>r z1En(?K3Sr7%63vjXu##^pXiM2UEsw}SIJ@Ss3Bql>3trC(T?ZGO#~?5_{FK_(P38n z-l03--o^WZ;N>nKkWi_jX6hyI%8WLG7bZ+;#o^}-eBghb2>*n1ekv5M$sxl=p|3w? zK{q^;b|n<)8)iZH6uhTj>yj(}WX6l<<4q=aPyC9asY>m~r6d|N$CAIL>I}5D2MS0{8Pww>x#p&!H7@bhcC079DhM`0X z2h5<mEe(9g2kmH}>TE|U4iOp*jd@EbpFkB5g$d9%ZX56J{ zaKo{=IDJo-7Nshyfo@}8{`^x!8?CVt$i6p(RY{Of-Jf0jEcA{WHKJsd2qF`l&qcA_ zi{+Y}|EH{^Nwke675H)qlP{h9!*`C-9M6Bd;80{>`gB7htg95@oTOW5CF%p_LFCP` zm>Ks>(p<&k+ae9x4qB4s)avORat~1k@jG|k!EDyxNpL)&Teea``Kx@jDH~Negug0w zVW`Abps+5{f{B|FC0r}0;a7~urVoqBA|ZGeK!e+R%io#fNj%hK&$s0@IUHtGi~Ee& z%Ziz#=_fj~f=+$@ZJ0_#-sV*kviW*}NP6)cEL8MNY(4!s(tv}IGwzZ`>`zR{VacuA zU<*dVZ?H<+g+Td=bz6SURwA9}=uJ^hmoi01fbkdZsx~1D;q|&edLFN3qQqxKrFqgGsVXgaMiRI z66IU}oN>JGf-K(^38dTUSrwEFuMNo4Ig{<+StPD9t=4Zd1Tg}1#wdqy`b%E%%6&KJ z=1$e;T{Vb3es@Y=fQqPFug7{7E)?(uI>LPGMA)(@WH!n4@)|I%tzT?)F7)Jrjc!N` z*gW!)iZV0N;S)ig_X>Y3hL0|U<U+NDalw{P00RoNmE?z}@O16rJ zHCXhcWnWEqH;4E##xC`2UQ-k3J@*XRpr*n)c}(cMGFAGcFx*=65;|%9voTFvr^OcB z*8NHG>>5=QDEQM@5ea*uJTE>Isl{_cbionbIEj|J;t?Hz;~6V4TpQOU2%Dm#5}*us5>6ptplo`35nDcN!^L8P zlF8_Fa@#q|&Qq7Cq$MzWfk?gBkQ%JLWFo(-h2#Z5q|M`1tsN%y_(eEN^L%sKY^1h( zCfNeU{GOihEidU~XWAH%CA%HJzLTEXuPsC`{!7cTy9|AvTgF?2PNlxb&e-hc*J-aE z+$aBaaeTZ{@;06+RqQ6kJUbUF8(zF5gEo;PuJo}z2#Gc1z?pvXGV$@tkvM?Y@}+@| z-`i1QHBYS$Fdd>VGnG$uuuyuH3HJ_ODb(G8`sJ5qx_QpX)oL#zcT9>ETmHjxS}NNd zlBuVA7wl(=eJ;VwUbW%tAHyCmiEI~-Y5w)&qgUzcE!3FDmJ-pX6IBw4JF|!np{?b7 zZ?ypyjM1Wi!WrJL5i$>o+}db3^fiQn0IR#(H%7+UL)Fz)Q`*Nbnwu@|{BB8Tf8$+s z&Cg$ip#;xE#{4;HJOE1gmW67Y+GisJrd0U%pg|b_)UluA0!dBGpAxxkhxdj}N3>?s zvwvO|k*Bjf{p_HXyVp&*R%b;NY$dnaB=`(#ndDT%vcH$8JSBrlL6SBu3EekxtV7&T zA@pUkDp8V`jBz6Hfi~mfXHOCJ)-&=swAK=t7z*Chfa3)@y80^&t8W)?iMtGNtUskf~y_78_-HthmT;d zt33QwS9g!3Ya*5f{1>nM*j=ir_``Ofm#wt$Z^&a9RAzR(!HnVnt##UqJ|qn&uf|Mp z713ASS^R^=YLxFuxbs#ycwu^mx`C6jAeSR?DBFCN>ZSFk9iuLaf~4Z6us%YUXjjff z==1?!YT?l6o}HycBI^-a9{B!QC}P@VSS~&Ahe8}g3U&c{r%OT(=<)mb<&W&&&NoWD zQp$x&Bp3)uOiR9tq&IxhBZxpFmG2L=B7}I!s7>~?%U9O{EXkokVp)!=Io%;=Ql#IF z_vRdwoTuivN);6P5Y=KbrV*K5pETC(tE|Q?yB5flq+@wS^~Sx=ZtyEx3Pyw})|}_k zUSdxU<4JH5M+Cp2h0*zQkXy5x}bLopvOn$+-11%B%j&F z$f}x4+^S_>)yvH_{aA2kWiEmR7%z%60w*fAMSkqJEPSoi!CYd>hUr%p^ztAK6ME6O ztiBJslfcd3xDYls+R!@A7Wx6{{W{W)6{-|h6!CO`0}Go)`EKUAT|*MCl*PDn-*1@> zGu0UYfKISWGrLhB2^$JkGuy+MSmoT$@|o zlbX&7m8kj_HJ*Q5DJ$|aMWA*(Xf&=}_=>7pdcVF1m7JdV=jq=fh$Qifn(~U~Rh~{d zxZL@U9ZLE2TcY-%J5nwQ%B0^L!0LLk)2>9M#_~9($}JypdaXl}pu#4|uea-+o9D%* zl}Cg{^vUl>D3SQy5@bkFi)^~LL?^;$GwzraelKM3WNP)=D{>7e3F27fmB?jn zCfRp=4wYVjR%o|Wyj&8JBey4SUV&Ftd-#-+&-F+WRcxITRay&D>;>~k_*$Y@vX+mG zgV~DpI8dRpFX(2{ZM}XQ_OGUQ^t?mXD%*piX=qdk?k%!? z9o6IeHA=hvkd`1&H}=a4i6h!V+pnxVN8B?B$e7NoSs~H~lI36Fq)wpZ?legWqreo5 z9fDj(Y>c7JZ!6fus-- zER`4(6GQ^93<`xT={^?T7W+P1dM6L{&nPFTFLWXFLse*u-~VZ)H85rLMfWflV+@71 zs%5OtDZ|jC#rRpZW>VJl zR-b|mRVLej#wkz>HQfn&9o;~JZU>ps!;L)Q0)BY+k`cj0`MTlAGhN$*&(AGmbn@z- z15^ldFp)>(eaaDq;9U@kwgPp)BQZ5A>pIg2p_CX+;maZPYcoWtPNW&h_ zMv>P{gRJ}a_be^gIGa~^G~nqU`lF4bu;#A3LK(_P&@$$iUM%9L%E)TT?4oKLC%G3e z5)7AtI1uc_P_8Xgs=Q$;f2}oUj=tP5sJ&x;jcM`qu4mgLVu?*x;kF;^b(5gBLH(2&MJ2M1YCMn#~9i{D&Cj(J<{r05aUW+xe5G6 z^+?1@E?09gK*>0bfhR+F##LIdWwJ9rH+}I(G%=ykW$qR@Vs_0_SJv4t0he8v*06C^ zha$5^SFIwe zPKzMb`F~7X)sZPia>P$nVVNh)vAu`g#QbF$C{tp?X1kDG4SKEpQYp(x1Hy<+->GwW zk4&&$sN{w&kL*v>ImS=qW$Hlv0hCsTU^eHCo5{~2U`E?veTfyIBxM2V#9~)O2T!r; zQX27occ)kzTOf_2z&_vg;Rp``VR|=w*p%LRw1mA3r9lo#|y* zROum-6ZhY&>PWeX24(zn&-pX7C;d_G1E3=}meBBdY)nQ9Rp9uJ&nk*=nA*o+40~M` z{=wGLOGh|+R@JHZg|`t{FOF1doUYXE*#jyUQ<3oh#_5*&B@!N$;TM|eBb+On6JBa` zg6i4b*>U#UbPV$R&v|tgZ3U!T+~z5)PYQScFPKFaJe?0zWkG_HXE;333>_jNPbyDM zZgxW*>t03F1!cu&1BU&89fOGx95?QF@MM4=1M0UV{vL%CGkcT^Kzs>8)gHkMVK})h z-N=SmqHVuNmQ1IWgdle+xutID?KAqw3zfl+q0_VB*zlGKF7!WE!{m=`@!`4HQ1-r@ z<3@xSeY4+9%s3E*>Hg-DCPcJc!P{%}ClCk|{iPj%wUQP%MEP2B`c+DH*_Fh9OC}-d zK8H)A7NFwZ4ikDZqi`31#FiOAV1R$r zf3?P?GdkxDN#eqb_fJ7l(x()I=sKz&^m~DX8z}Bs7qp)!0}81k@6UopGtJF`U6QU! z`qJxLdd9|F(E_7eVVAFvVw;K)Y_iXzmNg{=-?y2yK8d zML5i}{-KPFh4nw!=f@Xm^omkh3uuMHIo0?O7zC;@l07?n3SupodXqBq7H_(7u8#2- zTJDcJrH@+1&s*+2R@44TYK|73&sNPBVIDb|V`G8hJp<#+!~6Wdcmj$>-c)7gB1wC% zAo6zvm(j^p{1NRB@V3kuD&b@pu5?NEw}$y%`r5KlbV*M0UXU|f`$b_BEOmYuT}Ya( z0k?}V4)se%?61z#v~7}9J;rQNk4IYU-vDQn_!9BN)tnOV&sUj_o_E~i3atx9Oe3oz zaC@AjlpQT$T7LKv#l0X5oG;<|%ECc8->K|VkxJ)x@M4_k#L%AlC?6>(vo9?@3Ytx> zh5!srTOPD_;6JEFG$G3PkgjBs0ax&#+Hx|Ef7J`sK*(M?k1->40vgiH<0#T(VMt_l zbpkt4nv#Iw(C(p3uZ=&-ZmK3vAn8q@-^!|e`z9CIOn&tuLuaAC)a_; zOE0$ku*#`36OHe$TkbI8R>X8KHFg;FK3JM(N9~p_+OjDZHJP&zc{abf&!keN&5j^;+=6r^gOq z*rgm+akq8Upg@KMj%%3n;jG;;_Xhd8JB{!LgswDw$4Nq?E+`+{d?)@1QMCJl`vQOOs6mIo{rUPk@_>m4q6Go)r|6Fd*O<YH~rt-1{mGgN6>POlUifswL+XWDWDb{i}GjfuhO`Xcsgl2^$ zCU=0&V3eINoF?!cClL;JyipJ}DyZ4`nF8Gp;;QybDZaRl3SXyd%mv^$8A=PDzBQq# zbP5n_vVt&dUy~OnIGSh2p^LxTqvSW;OAO*`>C^*=td<5SFH%SKT^^Z9lx2e|1K5E^ zE(w%M)UKe`XUMc2f!vbf?wjr+!Yd`V9QMO)@G$=9sx$410->Z@_Y-sZF_$;UFFH3a z%XlwUPKmc!EK&8?x}o}0BFxvH1ClJB9RA9OV$a+?hbMdq$0DX(%DEj|H#!@2TygS$J9b<$9jB$fzhenSxWg(T<7AK_QOB37)VB_wJj=365XhqhF# zJBzm~L6#_!s_{>)o1X)|ea$zZ5qblOnPXQP`I>DeG`XN8(hAcpHM@TFp6$`-O_Dpf zqR9t#N|~H|s@y7h*ZcG7yCBQ_8=?XpR-d`QVtQ@}%P&KBYDG<}CTa=9KwfNl`*T@2 zwi@>O!B`yji-fK)P4N|~!5P}q&lUnp1&2rPA!g;mxQaGKqzW4pq=R7aDVRkNM~y)Zg<9v&URsCmt(1Ezec>SJJg~z(6Kh<_f^Y`Ru*zb zn{@>pwa+{ggYZIsU$4_XW7I4egzVJ0Dp)t*4&pusfP}z0vk#WpS%Ofyk*pSs ztf)j=;z!Yh4t|9U-QS4Q$_FG!{>p$oo z`Q%3_M^u^Cn+Q&(W!sfYHI0A1Ar|s*Nz;AE`=|ZQzW*j;pNtI7xB1q7)03t)Js<*| zr7Yxb->KF>8|jyw%;htAoVd%4^>OeLb(5l{9H?4trCEX=9Ho*r(pdYZo4chxM^S$= z>0pTCys<9o=ydQk_|p&u^Jj+CDqWfML7Cj&FY(ShHXMZs@eXgL^i{Exw{eV--)5t{ zUWms`*2s_V1=;8~V79)Wk^9??nktj-|Cutu!Mu*985E{oLw#<7!|wsRu!1-AOMwSq zipO5zuD4t9#PfE-@5hTc73-mrQn}lU_?2EU2LUI0Qx^BxozL}wPmPE)NOOUR6vscSY2v=;_AKT zR-=vK0ruLC(#xkXEzhV$3Kuqq88g_9yGsa3q`48*Y_g<=a$(_3?Aj-|E*oM(-xuXw zAi`1H%)-hh=;)L$&2iMYQVHhIB!{|J3Fc8m=@#?N=3m^Ybi)lp+lR@q{w$sz9Z;9A zHDGl?a?Jc6^a_6uoq&^QIk@;_2xYJ}WR${37a%hk_03kC&#ikA*HX5f_;t z{U1>EdekQrYOk(uqWB=sEo4ckUO>(J?i2L8!#nNhn~Y1@xPz^=*QcWi(enD1Pfjny z5u_9D6T)dR(FO7xC(S=HzHQ|t`ofPu;@+;>6H(n4M9QaI64+hf(>n-93u5t-(wTD< zs)hGSL0f(9a)iidU@BNtA=r{7;uxoU@8>Wb_|bw^{fCY3X(a2B=bPHbj8l@4R-UsP zsuH`uhx^EQObq3*Q!<&v(w#W$R}I3SuK$tGxF)f|vTHFoQY?H5mcY1U;S}14H5Z9u zcpJp7E917{=pR;r#D@-9J;^Q<_ZcrRM`W|8YcP|KH64xAPLS^{y^Pu~xFMJp7~6C| z8Ciw%;uTjvz6-whUUpxQS#0bev|V~j@sOO|HLi+V%UE%q9I0UeZ;9jvk3Rh}vFfki_4Hf))_-GcL#E#s*Js7*n_M@k{piD?Tc+H>2TamQF@F#Pm~Zw}87H%pom9zeYB0)hP$%p-5zLT`3> z@YT=W2az9bu;O^0OW>QaQ{En(piAB*(4Uc|xj?;FfpW$RyFEGhkMA7;V)0$Pw}B7e zzds*#&*mlYG*QW{uAFrb#}aV*+|di_L4XD4GHY=J{6mYD=YZ};A~Rl}`EI|>Q@Yq-cc&I4l?F-TzE_T1s zv7Q6u;KTP9xCZ;9wZ?9tcc4nZec96`dQMY4qqW6$a8cCJnT*p(S>Hw|R8i#12EI>xLu4`dMWH%Dy3iB!{Uu7|JOVXxbX5!Knyv>zId<=mV$J!-Yt zUH*J!0JYYlcKyEk^MV^VhTlBctY+xSwyXT~<6~uMRpHV)TZ`1|2X@u&61+y)rA*|< zbj=MugCs9Cen5>}0aJL__|3h|i67#)viAb;!|HBnJ+2lOv^3NiK6T6~NF;Rg5FVRo~0{eJu7 zyV97}{pmly;nn|(=GpF&YH9cOeHQLtUy!YDl|~L%1^@br(~Ei`TmGNgfBD;hQ)P(W&_JE2C!+cPJam1F-a7pcaIj$d*7RRA)p*YabN>T$ zHhhaV()%wiTg*TQqIBi6`wgJHl}o*D|L1^afimDfXz&Nr;6FAlw*SZG zOJhWL*GI-bh%q!qZ2gyw_-_O6ji_DsZ)+3oh5ttp`i3Eg*@nv9@m~t0EQ>iY|4YKF zIRp>#AF-qDz_IY~|K;|!{uFR9+1&7No^R_90SBC`|FV4YF3-7X!u-!rXj<9(Wa@uW zz6pWUH_bi&7hE9e0_9uQ|5q1Dy|z2d8-8$lwULc_qzUH9?rj(g%n&X)em~*N^bLN0 z!NviXWeMoM6CJOaEWZcSHF>?y?FV*QZN9`Qij_>gb3mrwoGQZKmz5Rf10k-E_l}3a Pe@@uz(e97``RD%ugR;NC literal 0 HcmV?d00001 diff --git a/src/metatrain/llpr/model.py b/src/metatrain/llpr/model.py index 34c6ff36da..1dd5f9e2f5 100644 --- a/src/metatrain/llpr/model.py +++ b/src/metatrain/llpr/model.py @@ -39,7 +39,7 @@ class LLPRUncertaintyModel(ModelInterface[ModelHypers]): - __checkpoint_version__ = 3 + __checkpoint_version__ = 4 # all torch devices and dtypes are supported, if they are supported by the wrapped # the check is performed in the trainer diff --git a/src/metatrain/llpr/tests/checkpoints/model-v4_trainer-v5.ckpt.gz b/src/metatrain/llpr/tests/checkpoints/model-v4_trainer-v5.ckpt.gz new file mode 100644 index 0000000000000000000000000000000000000000..505d257b6b9dcd36c67362a0429090bf0afa6834 GIT binary patch literal 8562 zcmX|`WmFa4|F-E4LE_NejUe3(Qc8C>9HgXMy4#?;8|icC5)cqj>F%xrocQbax1MLM zc`>s;*S@cN&%7MEL@caeaJ(}L!X0R5@5bXNWa;B&?d)Oi#p5T;YwPOigLIaU?N6fg zXN+xOh?F406Uz$?O-+>5dv_N{8P5qxAY39+^o@h~r-+H~dYanuVD@na_Vow+62|WJ zZVOpY^9!nrJ0o)?pDVMA%*Nxhg5%HUY84Bu&r-7wQlqhN=3)JJwcn1fe>jah6LO3KUaDTEs;PgOl%B%B)DnRJ3O-TYO4_$La1zjw0r~IUE|^ zZaVF`$tEepqp7=R#ExVW->0M2nr5T6?e`*+ohWfDLm7EYVx~wp2g}>795+wDiru*0 z;1UnxYe>(B@;owwt%P($#bnsknjyO-B|1W653x=Ye%GZ^2~-KeMmUc{Ht7r2AH*l8 z63xmzmYH)~hs}3!vn14yv&E}PaIhB-*tUyR6++((X9mA*6DS<=>NP6i^>}#Pk;PF| z5FFX#E*Z;tW&y38eGr|b9x0`;p zk|D53Qe0;-`d$4IZE8?pI%->6CwqWG`V;%{L9DtJak~`5$HDg!c6BzH67S%WDS7)P zaOO|fCUfB^+S;EBZOY_u0$--5MPmdy9hb@#tP7clu{J8!P^d)ZmJH##c)4^h`P3fq z26hvnPQ+3va-_txl@t&3A92yI$kKR|56*hgq|o~UmJ(vf{R7i7zx8497WyRoiYhCXKoTL(5K9^s{~cW8S~tr$DTjg znwJWZ>n{{lFvk|tlp7h7Eo|$XJru{%QoklU?eEjOuiLWdnI}2~D9?S56AF;Pn8}Hd z>+&QQ*W!ix8~7)iL8!YfyNEvVfmVP{>OF^ zs0m*}#1p&6{T5O}W~o%xq;8{V#(4TJbT@O{U#Yf~nMXY>GbAe$pUAIMjYEQXB)Q+y zDVtE%Uems@&8y!lO23lsQ;07$9RWVuWs^{z?O~ffg2F&u5*LOmQMnqoN}1 zXk`rZ$D-`crBgwDqB5$E`?w}^AJOFk4Ej_lVt0+3(afo{9EUVATDQu5l6v0Ie&o&l zLE3&eIinqD=&oqEjIQE_ld?+veS_KuW0tTb;QTZFr{zdH${quWWb3 zNcE6a{Pg};xJSjnd0JGTlDvfLTzvtKO;ttznssb&TdSop!HPn@}?j6q?No`aUJ;^Q__zYKP}0KB<-8soQvW6H9~MrR^O`Qv}G+ixh!= z0O-=*6RIB;o7^@zwpdel(Jf#;s^T=GoMlu_tnq8kAKs7CxeNL>M$^_Sh23R+DU;7i z6=0oaszilp#z;4OpVs~X)Sd)aGT-{}1di@}M=BfoB#yya8uU7QS8;Zt?0|G0_jV;w z@Q4Mw&n)pifEBqB`)Spzur=vf9u9en5yDd3=sE|I+QE;a&jG#_C_gl%xXGXA z+E@U}jK|Tj+`i*^qBu8t_rj(q;ti8s7dgE~oE;LVEXOP+G#h9fmWkTrt_WKG<%ASl z1`RDiO~l@~x8)eLia_R>00u_8!$PEP;9nnRt6@ z0mYm^GJHftCmEPVykUhQL2adg^tx>aVm1Lu`LXrYA25XPc$HqIjFbi!%3!!b7Q;y^ zl}MRo8T(L1a5!iS@T{ZU&{YrK0n{N`eMF+twpOrSTrwh0V7`#^WLp04(uZB7VEoA= zBeKNL-?kxeX zOG1i%WaT9Hm*`3sV~{??(-H+Qd_|(JzF(FCz?B2W&} z$}@fjM#dlx(2+tF=GS>4%YekAZ87nJq^>H1D}?u@;bFKIwFA$uaY7m*j3hxTyl|{= z(E#B}O!?RmUs3WL0B^CMQRck$G+sNg4o|qrqOKb^x#5YGM=)hn`*m~vjEzJT@-mVq z57O)@`)d}8p%N;XTmzf~FkZ*5zSs=zL4%i*3x486IR_IE@Q1%WL#W|~R$CvD1e*8aq!vA1bnR2+u3o4C&em8&G zD0X^`BS>-VF7Ag8e`UKd6B~$$d#5!D!i9W5MbJ+22~Eb1Uo6fi(mtX=aiOZ95P2S8;L9 zm12UhJ^4^QEL)22L|J45OUaUl95gdv36r#3t%+B);DVmK?;MR2p4T@R&~Ev#sJRWa zCsSF3M0u*FMl9W_PgLT$FPIiH)O61yw~GZrBeqXV9BLTfHHHB#kcVNG_7naj41dyA z|2WMK;(Nn7ndJx5S#BtYFfBDaFZ#iZtf{S%w8O7B_~F^Tt-1hp+er|MF{baCQH)IH zN^@4*&Vjl_PKF;;zBpvU$rAa@wFo95#>&y=ll+j}a5Z`Lt?sJ-9o0n;`W>nQ46CXT z>F$-%z-v2pM)mkO)rae21$+p(R|zRXy`RnG3iva@gDWu8g&&^Z;^Z7VPKZ7rSN}`i z1-YjxAu`FkE{l6g^wc2u1H?4opYpO#8yYC^+f z<*tF2n z3uSpTTh8c?9?SQ5;rXZT^~L&iaz&@2q`vPR1K+=9@I(fXUynwuY@`qD3S?0+p6KtrmZJIA_1%tdcy8vyaju6qN)PzQG}V6j#t3-C%4G1%y=H6RXt|YP z1x?6Zt(zG^4r%XRxSt>8rCn3L4R|TOGV)9awS!lZ#*HgrMY=m%ya3ri#npDX;ov># zFPfp1yQPMat2g}OIUZB7So1|W(|qaJ3sqquE6<-8r2@=Zo%?i_1gc!g*&Ve=h&pM9zB!;dEZ za8{W2Z;x2@4#zdWEJPXANCxy=R$HA4+nkP0fq&(`mg^$LZ9tU(`Q17@YQ~~k@w!~3 z{xL_{4)k>TG1H>Flq8DP&)|;tN1xn+z>lj%ejxd)0veRu(?13`m;6BvWS}@2p(ECy3VloHcpY*GF zec#+ZIVnV_$$8I>_^NwTu(MTWP5t5&sM;VH*|M25+tkuNLRs6}VdSdWz%$b}-##MF zWo}wmi5pqR?9@9@!@CjL0#6d*4Ca*EX|0>cMz9*hPV2kX$#=4@I~?uQ9w%Ye%DQb$_h94x86DyOWypJ-Xhux=Fm- ztxGVeCAx9#sG;fa`A5of{04i-2~wEU=C`-K9Cwt!ovs$IWS~yh3h8-rQqW6JAjUaC zZ42olOfq8dm3=amjFmm(Y?OKyp}_hOI0((ZhrO}2Mr3`O$$o_DVym?? z`D0hEH68k?CGs<1#7_%#ENrKEgHz&z3sAav_>mb^8Xa7T48cREd^>#ADyd>Hbmh%0 zx}pq~WQkzCL9y+_s>=<(YgsJf?oVanF}%Ap{8?n79~qb&VS)LB3Mp{G zWTXD8_(O+q^-SN8(FMwhlof2qMG-5kkgXaAg6X1&&|`c`)0IY<3@r+F6RrrS+Z3z0 zg<%&)6}C`A$!%%SmK?!Pt1gO9@JP1m&kwmSik2KiA=B&#j7bsM$EDB4H}Dk4Dw~Ig zVo>Xb_&mCtKxhBa&CS2=ov@EjUsu(0M4Y1M9wySAsWXCTtU-X zKWAiKsH;sw^F`J9fOPjlaB`(fQhBjG2!}hzB9dAI7;}gkH-t22$Y138u)|POv}U=< z`$>{)3JHS}t4I!slG~1?jqpIACPPZm>bJ07{2o>e03&WI9Kc|w)1q>Xq(3-AE@~f; zmDa1H{AxTMi2^mPSzd+sURP3Xv0>~g8bz7iF8WF~4=Fl6lISduSo^J#YY*7`Rw-%M zikU6VOx6qclhTPzL21n*i%Vq|Xa3t!z#g~?Y1D2Y&P3{CWtHHap7dvD?$1X_de~(Y zvf?J=-0uvpsPzk_m6IJRh^L#Abn%mn!FfD@xysZXjvD?DNjgpZL`RJ3Wp+Ln_{ai4FoxJjnALvL>8%IVIYA zVl^0N`lZGdiaWMR;#QCdq8D$RelQa1j~oO_cVdv4r9~(pSrBZJ%2$pY@TcnQA1Dhi z#Syz^G1!JHh7U(#FN)15=yw@!+iEr$eF1b*fVq%Y8GTGVEJkeg^8?x`Pn^~3>hS`^ z>1UC4DF^aqdUFn{P;NePI_3`=&x%row-SjuVzZ*{YjXOWdTAN2b5~7X|Hiqej#)$b z<&7Lv@PottF_^S7q>GcolD;WF5tnpT3GaizCNJ`8xjbgZwSDl*8k6fimzNgBB%v%^ zP)>HJ3p*UOiqbKEDW(LrCao6ntA>S&=TUkgdTCxW%*u(ff2$Af2OIc9Nw$AM?%C}wegCxH26 z#rSac6*qIi&-MeISiOXj7DRaQ;Ks-+*GM@ch%bMb0zW$n2zwuhjqx`;>TyN>KymW} z_{5x{6bniVfGhCRqksr2=&(L1ZiYp!#mYMnL719W`2a)_G8ikUkwiq0mrXD|enu1o zt{Wi^O(M83V(~zpl?Z~{^!)Zv2Mvys)KB(@JOY7Fwt+>VvfzM zVQ4#rXbOy%vsTbVD^)Z&X;RqBzXd^2R}94=M89lG$;3L67@J7((7zxP{lcg)cZ*Tn zY_M9Lc2Gq-Cm5Cb_h`m{?)-`w7f0yL;efm#8DdKZJHt-ceF4>4;{HwMFr=~=cp`f&qU5^3H zu&6KCmW*1K=iABAL#IlB&b)#d&3>HNei!BxjUv{%>|W{Z7dO_HrR;M{>>uNtS%scG zZ5j9Lwbuu;(_gi6E;l||EzzJF41FhNeb~88Jy7!^+VB#+SQlGw9^Z>{$Pr=dMV`BG z9Uujtp$h{q2R9Bx9tZse-5aR`)*Xv$Dl301AY-NbT>CHOb5^$sVSTRgauxP!PgJf|K zom!3-IAe$jd>w8`9gTub*F{`AAtu&WD!}m*>%u~3Lpjp1IW##l29J5Ly3%pzsDN8k ztZrxXFl~3P4V_Ngo%sGyc&&RDut9EVZrl7L>}Dloaam`|6RvOR_r!`#*F#WS8*>SS zs>8hAPNG`&DTwZGvHHAMitkjRj%p7>>JjKV}7$o(hM!IePi3Ns_v5hEL3Z%U zjioj(OQTnvAH!X@QRB~a5IfLJ1H{%yTu(kdCag}rSMGmi-3ZW_&~ir>BF{j_Q{@OJ z$UPENFhx0wwHv{#%#x+sbFVbqU-l`EdSQI}iyM@w&G|I{`e+F@Hm))|*TebLSSsMs zn@v4YUok86 zSzF^Vo4mY&f#9_3p!}KvYe%F|B_yOu5xm0@zz{;I(Q9{2ZI0qOJu)s-fzFr zE`vFQ&>k|Wds;VHSrk0~hX=k_Fzb+6=;#Pg;qqzcs%_vD5J?H`JIy}~~FyA{3rcu8!&|?Iz9!rs5a_Y^>X2e-onR>lVT
&p7(@}d1Hy!JEzM{8V z;Pd_K***R*VY37)U0+%nsa{RG@`f53>}*`JSw#N=!mm?UGof{(pbizBt*2+}+arga z4~`P>#?vm0WB%jl6Vo^Mprb3G!U-xVC?hOu1M<97&H7j|2AY_CaQ4PqZLV}U{3^B4 zdqcJ>zgrQrJZ+9>n4hsgScZlzwMjp8TfvVKoX7d?|*9*E7(X#0luE3si!?Im~q3p2s~^QYs{HW5+VGyZ7RyA|{tRP!M?{P!`rl z_I6#&n#2>4B9E+i#Xu}uB_Ai*JH)Q9M?4^xAAw&>&l+8z89@7@Tb98Mf#K3fma;vKFA2B3lusT}+uzb=hrj zX^YU+G|lON^d!s{Xvu*7TzwZxypDSp^o%H9nH6qA6nen)&0PFodsY2Z`x^xjeU^FD zh$%P#vG-?|db@$29}av={Tik+8s@mPgHQvl_skfs>wKd8J`7AmW7_GR^Vou0#v!H& zCd=<&wrOf6H_PR;jpdb3K{*vy%X_qqA6>V#4;nS3EkH}d zVJ?67-d~yc)F7V!Zax}28K>+QNe1yE{CSKP;LcF1OH3iX=uV1%l{ha8WmigWb*%I(C7|5RQ|Vw?($DMC1U|vXbYqN zPXN<=8AciQ?VpNgG3-;`sAX@vxs39WXkn MU*aRaF-S=N2OS*u*#H0l literal 0 HcmV?d00001 From 847b8f530213afdce1c8fd8459b790e168874c50 Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Mon, 9 Feb 2026 12:41:51 +0100 Subject: [PATCH 9/9] add checkpoints for classifier and llpr --- src/metatrain/llpr/checkpoints.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/metatrain/llpr/checkpoints.py b/src/metatrain/llpr/checkpoints.py index 1771815ed6..8b69420e8b 100644 --- a/src/metatrain/llpr/checkpoints.py +++ b/src/metatrain/llpr/checkpoints.py @@ -52,6 +52,22 @@ def model_update_v2_v3(checkpoint: dict) -> None: checkpoint["best_optimizer_state_dict"] = None +def model_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + # Upgrade the wrapped PET checkpoint to include system conditioning hypers + hypers = checkpoint["wrapped_model_checkpoint"]["model_data"]["model_hypers"] + if "system_conditioning" not in hypers: + hypers["system_conditioning"] = False + if "max_charge" not in hypers: + hypers["max_charge"] = 10 + if "max_spin" not in hypers: + hypers["max_spin"] = 10 + + def trainer_update_v1_v2(checkpoint: dict) -> None: """ Update trainer checkpoint from version 1 to version 2.