diff --git a/README.md b/README.md
index 0462690..1046880 100644
--- a/README.md
+++ b/README.md
@@ -385,7 +385,7 @@
 from upet.calculator import UPETCalculator
 from ase.build import bulk
 
 atoms = bulk("Si", cubic=True, a=5.43, crystalstructure="diamond")
-calculator = UPETCalculator(model="pet-mad-s", version="1.0.2", device="cpu", calculate_uncertainty=True, calculate_ensemble=True)
+calculator = UPETCalculator(model="pet-mad-s", version="1.0.2", device="cpu")
 atoms.calc = calculator
 energy = atoms.get_potential_energy()
diff --git a/pyproject.toml b/pyproject.toml
index 6993173..984c93c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -114,5 +114,8 @@ filterwarnings = [
     # deprecation warnings from torch_geometric
     "ignore:`torch_geometric.distributed` has been deprecated since 2.7.0 and will no longer be maintained",
     # metatrain checkpoint upgrade warning
-    "ignore: trying to upgrade an old model checkpoint with unknown version, this might fail and require manual modifications"
+    "ignore: trying to upgrade an old model checkpoint with unknown version, this might fail and require manual modifications",
+    # TorchScript deprecation warnings
+    "ignore: `torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`.",
+    "ignore: `torch.jit.load` is deprecated. Please switch to `torch.export`.",
 ]
diff --git a/src/upet/_models.py b/src/upet/_models.py
index 193e280..78759f9 100644
--- a/src/upet/_models.py
+++ b/src/upet/_models.py
@@ -198,7 +198,14 @@ def get_upet(
     # Generate metadata based on available info
     metadata = get_upet_metadata(model=model, size=size, version=str(version))
 
-    return loaded_model.export(metadata)
+    exported_model = loaded_model.export(metadata)
+
+    # TorchScript the model
+    for parameter in exported_model.parameters():
+        parameter.requires_grad = False
+    exported_model = exported_model.eval()
+    exported_model = torch.jit.script(exported_model)
+    return exported_model
 
 
 def save_upet(
@@ -236,7 +243,7 @@ def save_upet(
     else:
         output = "model.pt"
 
-    loaded_model.save(output)
+    torch.jit.save(loaded_model.to("cpu"), output)
 
     logging.info(f"Saved UPET model to {output}")
 
diff --git a/src/upet/calculator.py b/src/upet/calculator.py
index 8de7e93..a88f0a3 100644
--- a/src/upet/calculator.py
+++ b/src/upet/calculator.py
@@ -1,5 +1,4 @@
-import logging
-import os
+import warnings
 from typing import Dict, List, Optional, Tuple, Union
 
 import ase.calculators.calculator
@@ -9,7 +8,6 @@
 from metatomic.torch import ModelOutput
 from metatomic.torch.ase_calculator import MetatomicCalculator, SymmetrizedCalculator
 from packaging.version import Version
-from platformdirs import user_cache_dir
 
 from ._models import (
     _get_bandgap_model,
@@ -123,22 +121,6 @@ def __init__(
         # Branch 1: Loading from a local checkpoint
         if checkpoint_path is not None:
             model_name, size, version = parse_checkpoint_filename(checkpoint_path)
-
-            loaded_model = get_upet(
-                model=model_name,
-                size=size,
-                version=version,
-                checkpoint_path=checkpoint_path,
-            )
-
-            # Determine cache name
-            if model_name and size and version:
-                if not isinstance(version, Version):
-                    version = Version(version)
-                cache_name = f"{model_name}-{size}-v{version}"
-            else:
-                cache_name = os.path.split(checkpoint_path)[-1].replace(".ckpt", "")
-
         # Branch 2: Loading from HuggingFace
         else:
             if model is None:
@@ -159,12 +141,16 @@ def __init__(
                 requested_version=version if version != "latest" else None,
             )
 
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", category=DeprecationWarning)
+            warnings.simplefilter("ignore", category=UserWarning)
+
             loaded_model = get_upet(
                 model=model_name,
                 size=size,
                 version=version,
+                checkpoint_path=checkpoint_path,
             )
-        cache_name = f"{model_name}-{size}-v{version}"
 
         model_outputs = loaded_model.capabilities().outputs
         if non_conservative:
@@ -175,8 +161,8 @@ def __init__(
         if nc_forces_key not in model_outputs or nc_stress_key not in model_outputs:
             raise NotImplementedError(
                 "Non-conservative forces and stresses are not available for the "
-                f"model {cache_name}. Please run without non_conservative=True, "
-                "or choose another model."
+                f"model {model}, v{version}. Please run without "
+                "non_conservative=True, or choose another model."
             )
 
         if dtype is not None:
@@ -186,15 +172,8 @@ def __init__(
             loaded_model._capabilities.dtype = DTYPE_TO_STR[dtype]
             loaded_model = loaded_model.to(dtype=dtype, device=device)
 
-        cache_dir = user_cache_dir("upet", "metatensor")
-        os.makedirs(cache_dir, exist_ok=True)
-
-        pt_path = os.path.join(cache_dir, f"{cache_name}.pt")
-        logging.info(f"Exporting checkpoint to TorchScript at {pt_path}")
-        loaded_model.save(pt_path, collect_extensions=None)
-
         self.calculator = MetatomicCalculator(
-            pt_path,
+            loaded_model,
             extensions_directory=None,
             check_consistency=check_consistency,
             device=device,