From b1fada1d8a413eea1c6b195537cdb2e494ff8cff Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Fri, 13 Mar 2026 16:23:30 +0100 Subject: [PATCH] feat(torchsim): add variants support for model outputs Add variants parameter to MetatomicModel.__init__, following the same pattern as ase_calculator. Uses pick_output to resolve the energy output variant from the model's capabilities. This allows using models with multiple output variants (e.g. different functionals or uncertainty quantification outputs) by specifying which variant to use for energy evaluation. --- .../metatomic_torchsim/_model.py | 14 +++++++++++--- .../metatomic_torchsim/tests/test_model_loading.py | 12 ++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/python/metatomic_torchsim/metatomic_torchsim/_model.py b/python/metatomic_torchsim/metatomic_torchsim/_model.py index 09e7026db..00a00aba1 100644 --- a/python/metatomic_torchsim/metatomic_torchsim/_model.py +++ b/python/metatomic_torchsim/metatomic_torchsim/_model.py @@ -24,6 +24,7 @@ System, load_atomistic_model, pick_device, + pick_output, ) @@ -75,6 +76,7 @@ def __init__( check_consistency: bool = False, compute_forces: bool = True, compute_stress: bool = True, + variants: Optional[Dict[str, Optional[str]]] = None, ) -> None: """Initialize the metatomic model wrapper. @@ -91,6 +93,9 @@ def __init__( Useful for debugging but hurts performance. :param compute_forces: Compute atomic forces via autograd. :param compute_stress: Compute stress tensors via the strain trick. + :param variants: Dictionary mapping output names to variant names. If not + provided, the default variant is used for all outputs. See + :py:func:`metatomic.torch.pick_output` for details on variant selection. """ super().__init__() @@ -149,6 +154,10 @@ def __init__( "Only models with energy outputs can be used with TorchSim." ) + # Resolve output variants + variants = variants or {} + self._energy_key = pick_output("energy", capabilities.outputs, variants.get("energy")) + self._model = model.to(device=self._device) self._compute_forces = compute_forces self._compute_stress = compute_stress @@ -158,7 +167,7 @@ def __init__( self._evaluation_options = ModelEvaluationOptions( length_unit="angstrom", outputs={ - "energy": ModelOutput(quantity="energy", unit="eV", per_atom=False) + self._energy_key: ModelOutput(quantity="energy", unit="eV", per_atom=False) }, ) @@ -243,8 +252,7 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: check_consistency=self._check_consistency, ) - energy_values = model_outputs["energy"].block().values - + energy_values = model_outputs[self._energy_key].block().values results: Dict[str, torch.Tensor] = {} results["energy"] = energy_values.detach().squeeze(-1) diff --git a/python/metatomic_torchsim/tests/test_model_loading.py b/python/metatomic_torchsim/tests/test_model_loading.py index 2dc5ef687..237fc03f9 100644 --- a/python/metatomic_torchsim/tests/test_model_loading.py +++ b/python/metatomic_torchsim/tests/test_model_loading.py @@ -58,3 +58,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dummy_scripted = torch.jit.script(Dummy()) with pytest.raises(TypeError, match="must be 'AtomisticModel'"): MetatomicModel(model=dummy_scripted, device=DEVICE) + + + +def test_variants_parameter_accepted(lj_model): + """Variants parameter is accepted even for models without variants.""" + # The LJ test model has no variants, but the parameter should be accepted + model = MetatomicModel(model=lj_model, device=DEVICE, variants=None) + assert model._energy_key == "energy" + + # Explicit empty variants dict should also work + model = MetatomicModel(model=lj_model, device=DEVICE, variants={}) + assert model._energy_key == "energy" \ No newline at end of file