diff --git a/src/metatrain/experimental/classifier/model.py b/src/metatrain/experimental/classifier/model.py index 247939aaa5..996dc95249 100644 --- a/src/metatrain/experimental/classifier/model.py +++ b/src/metatrain/experimental/classifier/model.py @@ -20,7 +20,7 @@ class Classifier(ModelInterface[ModelHypers]): - __checkpoint_version__ = 1 + __checkpoint_version__ = 2 # all torch devices and dtypes are supported, if they are supported by the wrapped # model; the check is performed in the trainer @@ -339,7 +339,19 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: @classmethod def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: - # Currently at version 1, no upgrades needed yet + if checkpoint["model_ckpt_version"] == 1: + # v1 -> v2: wrapped PET model added system conditioning hypers + hypers = checkpoint["wrapped_model_checkpoint"]["model_data"][ + "model_hypers" + ] + if "system_conditioning" not in hypers: + hypers["system_conditioning"] = False + if "max_charge" not in hypers: + hypers["max_charge"] = 10 + if "max_spin" not in hypers: + hypers["max_spin"] = 10 + checkpoint["model_ckpt_version"] = 2 + if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__: raise RuntimeError( f"Unable to upgrade the checkpoint: the checkpoint is using model " diff --git a/src/metatrain/experimental/classifier/tests/checkpoints/model-v2_trainer-v1.ckpt.gz b/src/metatrain/experimental/classifier/tests/checkpoints/model-v2_trainer-v1.ckpt.gz new file mode 100644 index 0000000000..ed2e5fea31 Binary files /dev/null and b/src/metatrain/experimental/classifier/tests/checkpoints/model-v2_trainer-v1.ckpt.gz differ diff --git a/src/metatrain/llpr/checkpoints.py b/src/metatrain/llpr/checkpoints.py index 1771815ed6..8b69420e8b 100644 --- a/src/metatrain/llpr/checkpoints.py +++ b/src/metatrain/llpr/checkpoints.py @@ -52,6 +52,22 @@ def model_update_v2_v3(checkpoint: dict) -> None: checkpoint["best_optimizer_state_dict"] = None +def model_update_v3_v4(checkpoint: dict) -> None: + """ + Update a v3 checkpoint to v4. + + :param checkpoint: The checkpoint to update. + """ + # Upgrade the wrapped PET checkpoint to include system conditioning hypers + hypers = checkpoint["wrapped_model_checkpoint"]["model_data"]["model_hypers"] + if "system_conditioning" not in hypers: + hypers["system_conditioning"] = False + if "max_charge" not in hypers: + hypers["max_charge"] = 10 + if "max_spin" not in hypers: + hypers["max_spin"] = 10 + + def trainer_update_v1_v2(checkpoint: dict) -> None: """ Update trainer checkpoint from version 1 to version 2. 
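# --- Reviewer sketch (not part of the patch): how the new upgrade hooks act on
# an old checkpoint. The dict below is a stripped-down stand-in with only the
# nesting that model_update_v3_v4 touches; real checkpoints also carry state
# dicts, metadata, and the version fields managed by the surrounding upgrade
# machinery.
from metatrain.llpr.checkpoints import model_update_v3_v4

ckpt = {
    "wrapped_model_checkpoint": {
        # a pre-conditioning wrapped PET checkpoint: no conditioning hypers yet
        "model_data": {"model_hypers": {}},
    },
}
model_update_v3_v4(ckpt)
hypers = ckpt["wrapped_model_checkpoint"]["model_data"]["model_hypers"]
assert hypers["system_conditioning"] is False  # disabled by default
assert hypers["max_charge"] == 10 and hypers["max_spin"] == 10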
diff --git a/src/metatrain/llpr/model.py b/src/metatrain/llpr/model.py index 34c6ff36da..1dd5f9e2f5 100644 --- a/src/metatrain/llpr/model.py +++ b/src/metatrain/llpr/model.py @@ -39,7 +39,7 @@ class LLPRUncertaintyModel(ModelInterface[ModelHypers]): - __checkpoint_version__ = 3 + __checkpoint_version__ = 4 # all torch devices and dtypes are supported, if they are supported by the wrapped # the check is performed in the trainer diff --git a/src/metatrain/llpr/tests/checkpoints/model-v4_trainer-v5.ckpt.gz b/src/metatrain/llpr/tests/checkpoints/model-v4_trainer-v5.ckpt.gz new file mode 100644 index 0000000000..505d257b6b Binary files /dev/null and b/src/metatrain/llpr/tests/checkpoints/model-v4_trainer-v5.ckpt.gz differ diff --git a/src/metatrain/pet/checkpoints.py b/src/metatrain/pet/checkpoints.py index 020b1e6590..ef1967e2d2 100644 --- a/src/metatrain/pet/checkpoints.py +++ b/src/metatrain/pet/checkpoints.py @@ -265,6 +265,21 @@ def model_update_v10_v11(checkpoint: dict) -> None: checkpoint["model_data"]["model_hypers"]["attention_temperature"] = 1.0 +def model_update_v11_v12(checkpoint: dict) -> None: + """ + Update a v11 checkpoint to v12. + + :param checkpoint: The checkpoint to update. + """ + # Adding system conditioning hyperparameters (disabled by default) + if "system_conditioning" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["system_conditioning"] = False + if "max_charge" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["max_charge"] = 10 + if "max_spin" not in checkpoint["model_data"]["model_hypers"]: + checkpoint["model_data"]["model_hypers"]["max_spin"] = 10 + + ########################### # TRAINER ################# ########################### diff --git a/src/metatrain/pet/documentation.py b/src/metatrain/pet/documentation.py index 8eaec9c03e..0d72f0eb38 100644 --- a/src/metatrain/pet/documentation.py +++ b/src/metatrain/pet/documentation.py @@ -155,6 +155,17 @@ class ModelHypers(TypedDict): """Use ZBL potential for short-range repulsion""" long_range: LongRangeHypers = init_with_defaults(LongRangeHypers) """Long-range Coulomb interactions parameters.""" + system_conditioning: bool = False + """Enable charge and spin conditioning embeddings. When enabled, per-system + charge and spin multiplicity are embedded and added to node features at each + GNN layer, allowing different predictions for the same structure under + different electronic states.""" + max_charge: int = 10 + """Maximum absolute charge for the conditioning embedding table. Supports + charges in the range ``[-max_charge, +max_charge]``.""" + max_spin: int = 10 + """Maximum spin multiplicity (2S+1) for the conditioning embedding table. + Supports values in the range ``[1, max_spin]``.""" class TrainerHypers(TypedDict): diff --git a/src/metatrain/pet/model.py b/src/metatrain/pet/model.py index 757e778140..0ad34e4a31 100644 --- a/src/metatrain/pet/model.py +++ b/src/metatrain/pet/model.py @@ -28,6 +28,7 @@ from . import checkpoints from .documentation import ModelHypers +from .modules.conditioning import SystemConditioningEmbedding from .modules.finetuning import apply_finetuning_strategy from .modules.structures import systems_to_batch from .modules.transformer import CartesianTransformer @@ -48,7 +49,7 @@ class PET(ModelInterface[ModelHypers]): targets. 
""" - __checkpoint_version__ = 11 + __checkpoint_version__ = 12 __supported_devices__ = ["cuda", "cpu"] __supported_dtypes__ = [torch.float32, torch.float64] __default_metadata__ = ModelMetadata( @@ -142,6 +143,17 @@ def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None: ) self.edge_embedder = torch.nn.Embedding(num_atomic_species, self.d_pet) + if self.hypers.get("system_conditioning", False): + self.system_conditioning: Optional[SystemConditioningEmbedding] = ( + SystemConditioningEmbedding( + d_out=self.d_node, + max_charge=self.hypers.get("max_charge", 10), + max_spin=self.hypers.get("max_spin", 10), + ) + ) + else: + self.system_conditioning = None + self.node_heads = torch.nn.ModuleDict() self.edge_heads = torch.nn.ModuleDict() self.node_last_layers = torch.nn.ModuleDict() @@ -443,6 +455,23 @@ def forward( padding_mask=padding_mask, cutoff_factors=cutoff_factors, ) + + # Extract per-system charge and spin for conditioning + if self.system_conditioning is not None: + n_systems = len(systems) + charges = torch.zeros(n_systems, dtype=torch.long, device=device) + spins = torch.ones(n_systems, dtype=torch.long, device=device) + for i, system in enumerate(systems): + if "mtt::charge" in system.known_data(): + charges[i] = ( + system.get_data("mtt::charge").block().values.long() + ) + if "mtt::spin" in system.known_data(): + spins[i] = system.get_data("mtt::spin").block().values.long() + self.system_conditioning.validate(charges, spins) + featurizer_inputs["charge"] = charges + featurizer_inputs["spin"] = spins + featurizer_inputs["system_indices"] = system_indices node_features_list, edge_features_list = self._calculate_features( featurizer_inputs, use_manual_attention=use_manual_attention, @@ -628,6 +657,16 @@ def _feedforward_featurization_impl( use_manual_attention, ) + # Add system conditioning (charge/spin) to node features + if self.system_conditioning is not None: + output_node_embeddings = output_node_embeddings + ( + self.system_conditioning( + inputs["charge"], + inputs["spin"], + inputs["system_indices"], + ) + ) + # The GNN contraction happens by reordering the messages, # using a reversed neighbor list, so the new input message # from atom `j` to atom `i` in on the GNN layer N+1 is a @@ -687,6 +726,16 @@ def _residual_featurization_impl( inputs["cutoff_factors"], use_manual_attention, ) + # Add system conditioning (charge/spin) to node features + if self.system_conditioning is not None: + output_node_embeddings = output_node_embeddings + ( + self.system_conditioning( + inputs["charge"], + inputs["spin"], + inputs["system_indices"], + ) + ) + node_features_list.append(output_node_embeddings) edge_features_list.append(output_edge_embeddings) diff --git a/src/metatrain/pet/modules/conditioning.py b/src/metatrain/pet/modules/conditioning.py new file mode 100644 index 0000000000..f305e40f84 --- /dev/null +++ b/src/metatrain/pet/modules/conditioning.py @@ -0,0 +1,78 @@ +"""System-level conditioning embeddings for charge and spin.""" + +import torch + + +class SystemConditioningEmbedding(torch.nn.Module): + """Embeds per-system charge and spin multiplicity into per-atom features. + + Each system in a batch can have a different total charge and spin multiplicity. + These are embedded via learned lookup tables, concatenated, and projected down + to the model's feature dimension. The resulting per-system embedding is broadcast + to all atoms belonging to that system. + + :param d_out: Output embedding dimension (should match d_node). 
+    :param max_charge: Maximum absolute charge value. Supports charges in + the range ``[-max_charge, +max_charge]``. + :param max_spin: Maximum spin multiplicity (2S+1). Supports values in + the range ``[1, max_spin]``. + """ + + def __init__(self, d_out: int, max_charge: int = 10, max_spin: int = 10): + super().__init__() + self.max_charge = max_charge + self.max_spin = max_spin + d_inner = d_out + self.charge_embedding = torch.nn.Embedding(2 * max_charge + 1, d_inner) + self.spin_embedding = torch.nn.Embedding(max_spin, d_inner) + gate = torch.nn.Linear(d_inner, d_out) + torch.nn.init.zeros_(gate.weight) + torch.nn.init.zeros_(gate.bias) + self.project = torch.nn.Sequential( + torch.nn.Linear(2 * d_inner, d_inner), + torch.nn.SiLU(), + gate, + ) + + def validate(self, charge: torch.Tensor, spin: torch.Tensor) -> None: + """Check that charge and spin values are within the supported range. + + Call this outside of ``torch.compile`` regions to get descriptive errors. + + :param charge: Per-system total charge, shape ``[n_systems]``. + :param spin: Per-system spin multiplicity, shape ``[n_systems]``. + """ + if (charge < -self.max_charge).any() or (charge > self.max_charge).any(): + raise ValueError( + f"charge values must be in [{-self.max_charge}, " + f"{self.max_charge}], got min={charge.min().item()}, " + f"max={charge.max().item()}. Increase max_charge in " + f"model hypers to support wider charge ranges." + ) + if (spin < 1).any() or (spin > self.max_spin).any(): + raise ValueError( + f"spin multiplicity values must be in [1, " + f"{self.max_spin}], got min={spin.min().item()}, " + f"max={spin.max().item()}. Increase max_spin in " + f"model hypers to support higher spin multiplicities." + ) + + def forward( + self, + charge: torch.Tensor, + spin: torch.Tensor, + system_indices: torch.Tensor, + ) -> torch.Tensor: + """Compute per-atom conditioning features from per-system charge and spin. + + :param charge: Per-system total charge, shape ``[n_systems]``, integer. + :param spin: Per-system spin multiplicity (2S+1), shape ``[n_systems]``, + integer >= 1. + :param system_indices: Maps each atom to its system index, + shape ``[n_atoms]``. + :return: Per-atom conditioning features, shape ``[n_atoms, d_out]``. + """ + c_emb = self.charge_embedding(charge + self.max_charge) # [n_systems, d_out] + s_emb = self.spin_embedding(spin - 1) # [n_systems, d_out] + system_emb = self.project(torch.cat([c_emb, s_emb], dim=-1)) # [n_systems, d_out] + return system_emb[system_indices] # [n_atoms, d_out]
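# --- Reviewer sketch (not part of the patch): the zero-initialized "gate"
# Linear above makes a freshly constructed module an exact no-op, so enabling
# system_conditioning on an existing architecture leaves predictions unchanged
# until training moves the gate weights (this is why the tests below run a few
# optimizer steps before asserting that charge changes the output).
import torch

from metatrain.pet.modules.conditioning import SystemConditioningEmbedding

module = SystemConditioningEmbedding(d_out=8, max_charge=5, max_spin=5)
out = module(
    charge=torch.tensor([2, -1]),  # two systems
    spin=torch.tensor([1, 3]),
    system_indices=torch.tensor([0, 0, 1]),  # three atoms, broadcast per system
)
assert out.shape == (3, 8)
assert torch.all(out == 0.0)  # zero-init gate: identically zero before training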
+ """ + c_emb = self.charge_embedding(charge + self.max_charge) # [n_systems, d_out] + s_emb = self.spin_embedding(spin - 1) # [n_systems, d_out] + system_emb = self.project(torch.cat([c_emb, s_emb], dim=-1)) # [n_systems, d] + return system_emb[system_indices] # [n_atoms, d_out] diff --git a/src/metatrain/pet/tests/checkpoints/model-v12_trainer-v12.ckpt.gz b/src/metatrain/pet/tests/checkpoints/model-v12_trainer-v12.ckpt.gz new file mode 100644 index 0000000000..d2a289c968 Binary files /dev/null and b/src/metatrain/pet/tests/checkpoints/model-v12_trainer-v12.ckpt.gz differ diff --git a/src/metatrain/pet/tests/test_conditioning.py b/src/metatrain/pet/tests/test_conditioning.py new file mode 100644 index 0000000000..daa825ba45 --- /dev/null +++ b/src/metatrain/pet/tests/test_conditioning.py @@ -0,0 +1,242 @@ +"""Tests for charge and spin conditioning embeddings in PET.""" + +import copy + +import pytest +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatomic.torch import ModelOutput, System + +from metatrain.pet import PET +from metatrain.pet.modules.conditioning import SystemConditioningEmbedding +from metatrain.utils.architectures import get_default_hypers +from metatrain.utils.data import DatasetInfo +from metatrain.utils.data.target_info import get_energy_target_info +from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists + + +def _small_hypers(system_conditioning=True, **kwargs): + """Return minimal PET hypers with system conditioning.""" + hypers = copy.deepcopy(get_default_hypers("pet")["model"]) + hypers["d_pet"] = 8 + hypers["d_head"] = 8 + hypers["d_node"] = 8 + hypers["d_feedforward"] = 8 + hypers["num_heads"] = 1 + hypers["num_attention_layers"] = 1 + hypers["num_gnn_layers"] = 1 + hypers["system_conditioning"] = system_conditioning + hypers.update(kwargs) + return hypers + + +def _dataset_info(): + return DatasetInfo( + length_unit="Angstrom", + atomic_types=[1, 6], + targets={ + "energy": get_energy_target_info( + "energy", {"quantity": "energy", "unit": "eV"} + ) + }, + ) + + +def _make_scalar_tmap(value: int, property_name: str = "value") -> TensorMap: + """Create a scalar TensorMap with a single float value (System requires float).""" + return TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.tensor([[float(value)]]), + samples=Labels("system", torch.tensor([[0]])), + components=[], + properties=Labels(property_name, torch.tensor([[0]])), + ) + ], + ) + + +def _make_system(model, charge=None, spin=None): + """Create a simple 2-atom system with optional charge/spin.""" + system = System( + types=torch.tensor([6, 1]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), + cell=torch.zeros(3, 3), + pbc=torch.tensor([False, False, False]), + ) + if charge is not None: + system.add_data("mtt::charge", _make_scalar_tmap(charge, "charge")) + if spin is not None: + system.add_data("mtt::spin", _make_scalar_tmap(spin, "spin")) + return get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + + +def test_conditioning_shapes(): + """SystemConditioningEmbedding produces [n_atoms, d_out].""" + d_out = 16 + module = SystemConditioningEmbedding(d_out=d_out, max_charge=5, max_spin=5) + + charge = torch.tensor([0, 2]) # 2 systems + spin = torch.tensor([1, 3]) # 2 systems + system_indices = torch.tensor([0, 0, 0, 1, 1]) # 5 atoms total + + out = module(charge, spin, system_indices) + assert out.shape == (5, d_out) + assert out.dtype == torch.float32 + + +def 
test_conditioning_different_d_node_d_pet(): + """Conditioning works when d_node != d_pet (the common case).""" + hypers = _small_hypers(d_pet=8, d_node=16) + model = PET(hypers, _dataset_info()) + model.eval() + + system = _make_system(model, charge=1, spin=2) + outputs = {"energy": ModelOutput(per_atom=False)} + with torch.no_grad(): + result = model([system], outputs) + assert "energy" in result + + +def _train_steps(model, n_steps=10): + """Do a few optimizer steps to break the zero-init of the conditioning gate.""" + model.train() + for _ in range(n_steps): + system = _make_system(model, charge=2, spin=3) + outputs = {"energy": ModelOutput(per_atom=False)} + result = model([system], outputs) + loss = result["energy"].block().values.sum() + loss.backward() + with torch.no_grad(): + for p in model.parameters(): + if p.grad is not None: + p -= 0.01 * p.grad + p.grad.zero_() + model.eval() + + +def test_conditioning_changes_output(): + """Same structure with different charges should produce different predictions.""" + hypers = _small_hypers() + model = PET(hypers, _dataset_info()) + _train_steps(model) + + system_neutral = _make_system(model, charge=0, spin=1) + system_charged = _make_system(model, charge=2, spin=1) + + outputs = {"energy": ModelOutput(per_atom=False)} + with torch.no_grad(): + result_neutral = model([system_neutral], outputs) + result_charged = model([system_charged], outputs) + + e_neutral = result_neutral["energy"].block().values + e_charged = result_charged["energy"].block().values + assert not torch.allclose(e_neutral, e_charged), ( + "Different charges should produce different energies" + ) + + +def test_conditioning_disabled_unchanged(): + """With system_conditioning=False, no conditioning module should exist.""" + hypers_off = _small_hypers(system_conditioning=False) + model = PET(hypers_off, _dataset_info()) + + assert model.system_conditioning is None + + # Model should still run fine + model.eval() + system = _make_system(model) + outputs = {"energy": ModelOutput(per_atom=False)} + with torch.no_grad(): + result = model([system], outputs) + assert "energy" in result + + +def test_conditioning_gradients_flow(): + """Gradients should flow through the conditioning embeddings.""" + module = SystemConditioningEmbedding(d_out=8, max_charge=5, max_spin=5) + + charge = torch.tensor([1]) + spin = torch.tensor([2]) + system_indices = torch.tensor([0, 0]) + + out = module(charge, spin, system_indices) + loss = out.sum() + loss.backward() + + assert module.charge_embedding.weight.grad is not None + assert module.spin_embedding.weight.grad is not None + assert module.project[0].weight.grad is not None + + +def test_conditioning_batch_independence(): + """Changing charge of one system in a batch should not affect others.""" + hypers = _small_hypers() + model = PET(hypers, _dataset_info()) + _train_steps(model) + + system_a = _make_system(model, charge=0, spin=1) + system_b_v1 = _make_system(model, charge=1, spin=1) + system_b_v2 = _make_system(model, charge=3, spin=2) + + outputs = {"energy": ModelOutput(per_atom=False)} + with torch.no_grad(): + result_v1 = model([system_a, system_b_v1], outputs) + result_v2 = model([system_a, system_b_v2], outputs) + + # Energy of system_a should be the same in both batches + e_a_v1 = result_v1["energy"].block().values[0] + e_a_v2 = result_v2["energy"].block().values[0] + torch.testing.assert_close(e_a_v1, e_a_v2) + + # Energy of system_b should differ between batches + e_b_v1 = result_v1["energy"].block().values[1] + e_b_v2 = 
result_v2["energy"].block().values[1] + assert not torch.allclose(e_b_v1, e_b_v2) + + +def test_conditioning_default_values(): + """Systems without explicit charge/spin should use defaults (charge=0, spin=1).""" + hypers = _small_hypers() + model = PET(hypers, _dataset_info()) + model.eval() + + # System with no charge/spin data (should default to charge=0, spin=1) + system_default = _make_system(model) + # System with explicit charge=0, spin=1 + system_explicit = _make_system(model, charge=0, spin=1) + + outputs = {"energy": ModelOutput(per_atom=False)} + with torch.no_grad(): + result_default = model([system_default], outputs) + result_explicit = model([system_explicit], outputs) + + e_default = result_default["energy"].block().values + e_explicit = result_explicit["energy"].block().values + torch.testing.assert_close(e_default, e_explicit) + + +def test_conditioning_out_of_range(): + """Charges or spins outside the supported range raise ValueError.""" + module = SystemConditioningEmbedding(d_out=8, max_charge=3, max_spin=4) + + # charge too positive + with pytest.raises(ValueError, match=r"charge values must be in \[-3, 3\]"): + module.validate(torch.tensor([5]), torch.tensor([1])) + + # charge too negative + with pytest.raises(ValueError, match=r"charge values must be in \[-3, 3\]"): + module.validate(torch.tensor([-4]), torch.tensor([1])) + + # spin too high + with pytest.raises( + ValueError, match=r"spin multiplicity values must be in \[1, 4\]" + ): + module.validate(torch.tensor([0]), torch.tensor([5])) + + # spin too low (0 is invalid, minimum is 1) + with pytest.raises( + ValueError, match=r"spin multiplicity values must be in \[1, 4\]" + ): + module.validate(torch.tensor([0]), torch.tensor([0])) diff --git a/src/metatrain/share/base_hypers.py b/src/metatrain/share/base_hypers.py index 474ab8bcb6..e6abb3c302 100644 --- a/src/metatrain/share/base_hypers.py +++ b/src/metatrain/share/base_hypers.py @@ -31,6 +31,14 @@ class ArchitectureBaseHypers(TypedDict): """ +@with_config(ConfigDict(extra="forbid", strict=True)) +class SystemDataKeyHypers(TypedDict): + """Reference to a per-system scalar stored in a memmap ``.bin`` file.""" + + key: str + """Filename stem of the ``.bin`` file (e.g. ``q`` reads ``q.bin``).""" + + @with_config(ConfigDict(extra="forbid", strict=True)) class SystemsHypers(TypedDict): """Hyperparameters for the systems in the dataset.""" @@ -51,6 +59,12 @@ class SystemsHypers(TypedDict): The list of possible length units is available `here `_.""" + charge: NotRequired[SystemDataKeyHypers] + """Per-system total charge stored in a memmap ``.bin`` file. Only used + with memmap datasets and PET's ``system_conditioning`` feature.""" + spin: NotRequired[SystemDataKeyHypers] + """Per-system spin multiplicity (2S+1) stored in a memmap ``.bin`` file. + Only used with memmap datasets and PET's ``system_conditioning`` feature.""" @with_config(ConfigDict(extra="forbid", strict=True)) diff --git a/src/metatrain/utils/data/dataset.py b/src/metatrain/utils/data/dataset.py index 9c157f52fd..fc44af241e 100644 --- a/src/metatrain/utils/data/dataset.py +++ b/src/metatrain/utils/data/dataset.py @@ -979,9 +979,19 @@ class MemmapDataset(TorchDataset): :param path: Path to the directory containing the dataset. :param target_options: Dictionary containing the target configurations, in the format corresponding to metatrain yaml input files. + :param system_options: Optional dictionary with system-level options. 
diff --git a/src/metatrain/utils/data/dataset.py b/src/metatrain/utils/data/dataset.py index 9c157f52fd..fc44af241e 100644 --- a/src/metatrain/utils/data/dataset.py +++ b/src/metatrain/utils/data/dataset.py @@ -979,9 +979,19 @@ class MemmapDataset(TorchDataset): :param path: Path to the directory containing the dataset. :param target_options: Dictionary containing the target configurations, in the format corresponding to metatrain yaml input files. + :param system_options: Optional dictionary with system-level options. Supported + keys are ``charge`` (with sub-key ``key`` specifying the ``.bin`` filename + stem) and ``spin`` (same format). These are loaded as per-system scalars and + attached to each ``System`` via ``add_data("mtt::charge", ...)`` / + ``add_data("mtt::spin", ...)``. """ - def __init__(self, path: Union[str, Path], target_options: Dict[str, Any]) -> None: + def __init__( + self, + path: Union[str, Path], + target_options: Dict[str, Any], + system_options: Optional[Dict[str, Any]] = None, + ) -> None: path = Path(path) self.target_config = target_options self.sample_class = namedtuple( @@ -999,6 +1009,21 @@ def __init__(self, path: Union[str, Path], target_options: Dict[str, Any]) -> No path / "momenta.bin", (self.na[-1], 3), "float32", mode="r" ) + # Optional per-system charge and spin arrays + self.charge_array: Optional[MemmapArray] = None + self.spin_array: Optional[MemmapArray] = None + if system_options is not None: + if "charge" in system_options: + charge_key = system_options["charge"]["key"] + self.charge_array = MemmapArray( + path / f"{charge_key}.bin", (self.ns,), "float32", mode="r" + ) + if "spin" in system_options: + spin_key = system_options["spin"]["key"] + self.spin_array = MemmapArray( + path / f"{spin_key}.bin", (self.ns,), "float32", mode="r" + ) + # Register arrays pointing to the targets self.target_arrays = {} for target_key, single_target_options in target_options.items(): @@ -1075,6 +1100,48 @@ def __getitem__(self, i: int) -> Any: pbc=torch.logical_not(torch.all(c == 0.0, dim=1)), ) + # Attach optional per-system charge and spin data + if self.charge_array is not None: + system.add_data( + "mtt::charge", + TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.tensor( + [[float(self.charge_array[i])]], + dtype=torch.float64, + ), + samples=Labels( + "system", torch.tensor([[i]], dtype=torch.int32) + ), + components=[], + properties=Labels("charge", torch.tensor([[0]])), + ) + ], + ), + ) + if self.spin_array is not None: + system.add_data( + "mtt::spin", + TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.tensor( + [[float(self.spin_array[i])]], + dtype=torch.float64, + ), + samples=Labels( + "system", torch.tensor([[i]], dtype=torch.int32) + ), + components=[], + properties=Labels("spin", torch.tensor([[0]])), + ) + ], + ), + ) + target_dict = {} for target_key, target_options in self.target_config.items(): target_array = self.target_arrays[target_key] diff --git a/src/metatrain/utils/data/get_dataset.py b/src/metatrain/utils/data/get_dataset.py index 16a964a1d5..6d94877628 100644 --- a/src/metatrain/utils/data/get_dataset.py +++ b/src/metatrain/utils/data/get_dataset.py @@ -38,7 +38,11 @@ def get_dataset( if "extra_data" in options: extra_data_info_dictionary = dataset.get_target_info(options["extra_data"]) elif Path(options["systems"]["read_from"]).is_dir(): # memmap dataset - dataset = MemmapDataset(options["systems"]["read_from"], options["targets"]) + dataset = MemmapDataset( + options["systems"]["read_from"], + options["targets"], + system_options=options["systems"], + ) target_info_dictionary = dataset.get_target_info() else: systems = read_systems(