From 8998f9ae1579f0f858ab293957d6e1fd400c0e1a Mon Sep 17 00:00:00 2001 From: Arslan Mazitov Date: Wed, 4 Feb 2026 21:21:35 +0100 Subject: [PATCH 1/6] Torch-scripting the model without saving --- src/upet/_models.py | 11 +++++++++-- src/upet/calculator.py | 25 +++---------------------- 2 files changed, 12 insertions(+), 24 deletions(-) diff --git a/src/upet/_models.py b/src/upet/_models.py index 193e280..78759f9 100644 --- a/src/upet/_models.py +++ b/src/upet/_models.py @@ -198,7 +198,14 @@ def get_upet( # Generate metadata based on available info metadata = get_upet_metadata(model=model, size=size, version=str(version)) - return loaded_model.export(metadata) + exported_model = loaded_model.export(metadata) + + # TorchScript the model + for parameter in exported_model.parameters(): + parameter.requires_grad = False + exported_model = exported_model.eval() + exported_model = torch.jit.script(exported_model) + return exported_model def save_upet( @@ -236,7 +243,7 @@ def save_upet( else: output = "model.pt" - loaded_model.save(output) + torch.jit.save(loaded_model.to("cpu"), output) logging.info(f"Saved UPET model to {output}") diff --git a/src/upet/calculator.py b/src/upet/calculator.py index 8de7e93..6cfa0b3 100644 --- a/src/upet/calculator.py +++ b/src/upet/calculator.py @@ -1,5 +1,3 @@ -import logging -import os from typing import Dict, List, Optional, Tuple, Union import ase.calculators.calculator @@ -9,7 +7,6 @@ from metatomic.torch import ModelOutput from metatomic.torch.ase_calculator import MetatomicCalculator, SymmetrizedCalculator from packaging.version import Version -from platformdirs import user_cache_dir from ._models import ( _get_bandgap_model, @@ -131,14 +128,6 @@ def __init__( checkpoint_path=checkpoint_path, ) - # Determine cache name - if model_name and size and version: - if not isinstance(version, Version): - version = Version(version) - cache_name = f"{model_name}-{size}-v{version}" - else: - cache_name = os.path.split(checkpoint_path)[-1].replace(".ckpt", "") - # Branch 2: Loading from HuggingFace else: if model is None: @@ -164,7 +153,6 @@ def __init__( size=size, version=version, ) - cache_name = f"{model_name}-{size}-v{version}" model_outputs = loaded_model.capabilities().outputs if non_conservative: @@ -175,8 +163,8 @@ def __init__( if nc_forces_key not in model_outputs or nc_stress_key not in model_outputs: raise NotImplementedError( "Non-conservative forces and stresses are not available for the " - f"model {cache_name}. Please run without non_conservative=True, " - "or choose another model." + f"model {model}, v{version}. Please run without " + "non_conservative=True, or choose another model." ) if dtype is not None: @@ -186,15 +174,8 @@ def __init__( loaded_model._capabilities.dtype = DTYPE_TO_STR[dtype] loaded_model = loaded_model.to(dtype=dtype, device=device) - cache_dir = user_cache_dir("upet", "metatensor") - os.makedirs(cache_dir, exist_ok=True) - - pt_path = os.path.join(cache_dir, f"{cache_name}.pt") - logging.info(f"Exporting checkpoint to TorchScript at {pt_path}") - loaded_model.save(pt_path, collect_extensions=None) - self.calculator = MetatomicCalculator( - pt_path, + loaded_model, extensions_directory=None, check_consistency=check_consistency, device=device, From 43fe69ee5fcd18aa4ae33572ce0a89e69f75e4e8 Mon Sep 17 00:00:00 2001 From: Arslan Mazitov Date: Wed, 4 Feb 2026 21:21:56 +0100 Subject: [PATCH 2/6] Removed unused args from the UQ docs --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0462690..1046880 100644 --- a/README.md +++ b/README.md @@ -385,7 +385,7 @@ from upet.calculator import UPETCalculator from ase.build import bulk atoms = bulk("Si", cubic=True, a=5.43, crystalstructure="diamond") -calculator = UPETCalculator(model="pet-mad-s", version="1.0.2", device="cpu", calculate_uncertainty=True, calculate_ensemble=True) +calculator = UPETCalculator(model="pet-mad-s", version="1.0.2", device="cpu") atoms.calc = calculator energy = atoms.get_potential_energy() From f057ac40834ca2ede6ce512a3d64690565ad44d6 Mon Sep 17 00:00:00 2001 From: Arslan Mazitov Date: Wed, 4 Feb 2026 21:31:49 +0100 Subject: [PATCH 3/6] Restricted the pytorch version to avoid torch-script deprecation warnings --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 6993173..5ed3b37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ maintainers = [ ] dependencies = [ + "torch<2.10.0", "metatrain==2026.1", "huggingface_hub", "hf_xet", From e8e0a8dedd50423bd55576150cdfda7f2b50b34c Mon Sep 17 00:00:00 2001 From: Arslan Mazitov Date: Thu, 5 Feb 2026 13:08:08 +0100 Subject: [PATCH 4/6] Reverted the torch requirements and silenced torch-script deprecation warning --- pyproject.toml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5ed3b37..738b5d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,6 @@ maintainers = [ ] dependencies = [ - "torch<2.10.0", "metatrain==2026.1", "huggingface_hub", "hf_xet", @@ -115,5 +114,7 @@ filterwarnings = [ # deprecation warnings from torch_geometric "ignore:`torch_geometric.distributed` has been deprecated since 2.7.0 and will no longer be maintained", # metatrain checkpoint upgrade warning - "ignore: trying to upgrade an old model checkpoint with unknown version, this might fail and require manual modifications" + "ignore: trying to upgrade an old model checkpoint with unknown version, this might fail and require manual modifications", + # TorchScript deprecation warning + "ignore: `torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`." ] From 0d415b5f93db869c9a4f9d651f05a32298c4c1ac Mon Sep 17 00:00:00 2001 From: Arslan Mazitov Date: Thu, 5 Feb 2026 13:15:38 +0100 Subject: [PATCH 5/6] Silenced jit.load warning --- pyproject.toml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 738b5d7..984c93c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,6 +115,7 @@ filterwarnings = [ "ignore:`torch_geometric.distributed` has been deprecated since 2.7.0 and will no longer be maintained", # metatrain checkpoint upgrade warning "ignore: trying to upgrade an old model checkpoint with unknown version, this might fail and require manual modifications", - # TorchScript deprecation warning - "ignore: `torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`." + # TorchScript deprecation warnings + "ignore: `torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`.", + "ignore: `torch.jit.load` is deprecated. Please switch to `torch.export`.", ] From f74032f2e96c5d8d08723e8173099e939f33ba86 Mon Sep 17 00:00:00 2001 From: Arslan Mazitov Date: Thu, 5 Feb 2026 14:53:01 +0100 Subject: [PATCH 6/6] Suppressing the warnings --- src/upet/calculator.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/upet/calculator.py b/src/upet/calculator.py index 6cfa0b3..a88f0a3 100644 --- a/src/upet/calculator.py +++ b/src/upet/calculator.py @@ -1,3 +1,4 @@ +import warnings from typing import Dict, List, Optional, Tuple, Union import ase.calculators.calculator @@ -120,14 +121,6 @@ def __init__( # Branch 1: Loading from a local checkpoint if checkpoint_path is not None: model_name, size, version = parse_checkpoint_filename(checkpoint_path) - - loaded_model = get_upet( - model=model_name, - size=size, - version=version, - checkpoint_path=checkpoint_path, - ) - # Branch 2: Loading from HuggingFace else: if model is None: @@ -148,10 +141,15 @@ def __init__( requested_version=version if version != "latest" else None, ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=DeprecationWarning) + warnings.simplefilter("ignore", category=UserWarning) + loaded_model = get_upet( model=model_name, size=size, version=version, + checkpoint_path=checkpoint_path, ) model_outputs = loaded_model.capabilities().outputs