Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ from upet.calculator import UPETCalculator
from ase.build import bulk

atoms = bulk("Si", cubic=True, a=5.43, crystalstructure="diamond")
calculator = UPETCalculator(model="pet-mad-s", version="1.0.2", device="cpu", calculate_uncertainty=True, calculate_ensemble=True)
calculator = UPETCalculator(model="pet-mad-s", version="1.0.2", device="cpu")
atoms.calc = calculator
energy = atoms.get_potential_energy()

Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,8 @@ filterwarnings = [
# deprecation warnings from torch_geometric
"ignore:`torch_geometric.distributed` has been deprecated since 2.7.0 and will no longer be maintained",
# metatrain checkpoint upgrade warning
"ignore: trying to upgrade an old model checkpoint with unknown version, this might fail and require manual modifications"
"ignore: trying to upgrade an old model checkpoint with unknown version, this might fail and require manual modifications",
# TorchScript deprecation warnings
"ignore: `torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`.",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will hide the warnings from CI, but users would still get them. Are you fine with this?

"ignore: `torch.jit.load` is deprecated. Please switch to `torch.export`.",
]
11 changes: 9 additions & 2 deletions src/upet/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,14 @@ def get_upet(

# Generate metadata based on available info
metadata = get_upet_metadata(model=model, size=size, version=str(version))
return loaded_model.export(metadata)
exported_model = loaded_model.export(metadata)

# TorchScript the model
for parameter in exported_model.parameters():
parameter.requires_grad = False
exported_model = exported_model.eval()
exported_model = torch.jit.script(exported_model)
return exported_model


def save_upet(
Expand Down Expand Up @@ -236,7 +243,7 @@ def save_upet(
else:
output = "model.pt"

loaded_model.save(output)
torch.jit.save(loaded_model.to("cpu"), output)
logging.info(f"Saved UPET model to {output}")


Expand Down
39 changes: 9 additions & 30 deletions src/upet/calculator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import os
import warnings
from typing import Dict, List, Optional, Tuple, Union

import ase.calculators.calculator
Expand All @@ -9,7 +8,6 @@
from metatomic.torch import ModelOutput
from metatomic.torch.ase_calculator import MetatomicCalculator, SymmetrizedCalculator
from packaging.version import Version
from platformdirs import user_cache_dir

from ._models import (
_get_bandgap_model,
Expand Down Expand Up @@ -123,22 +121,6 @@ def __init__(
# Branch 1: Loading from a local checkpoint
if checkpoint_path is not None:
model_name, size, version = parse_checkpoint_filename(checkpoint_path)

loaded_model = get_upet(
model=model_name,
size=size,
version=version,
checkpoint_path=checkpoint_path,
)

# Determine cache name
if model_name and size and version:
if not isinstance(version, Version):
version = Version(version)
cache_name = f"{model_name}-{size}-v{version}"
else:
cache_name = os.path.split(checkpoint_path)[-1].replace(".ckpt", "")

# Branch 2: Loading from HuggingFace
else:
if model is None:
Expand All @@ -159,12 +141,16 @@ def __init__(
requested_version=version if version != "latest" else None,
)

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=DeprecationWarning)
warnings.simplefilter("ignore", category=UserWarning)

loaded_model = get_upet(
model=model_name,
size=size,
version=version,
checkpoint_path=checkpoint_path,
)
cache_name = f"{model_name}-{size}-v{version}"

model_outputs = loaded_model.capabilities().outputs
if non_conservative:
Expand All @@ -175,8 +161,8 @@ def __init__(
if nc_forces_key not in model_outputs or nc_stress_key not in model_outputs:
raise NotImplementedError(
"Non-conservative forces and stresses are not available for the "
f"model {cache_name}. Please run without non_conservative=True, "
"or choose another model."
f"model {model}, v{version}. Please run without "
"non_conservative=True, or choose another model."
)

if dtype is not None:
Expand All @@ -186,15 +172,8 @@ def __init__(
loaded_model._capabilities.dtype = DTYPE_TO_STR[dtype]
loaded_model = loaded_model.to(dtype=dtype, device=device)

cache_dir = user_cache_dir("upet", "metatensor")
os.makedirs(cache_dir, exist_ok=True)

pt_path = os.path.join(cache_dir, f"{cache_name}.pt")
logging.info(f"Exporting checkpoint to TorchScript at {pt_path}")
loaded_model.save(pt_path, collect_extensions=None)

self.calculator = MetatomicCalculator(
pt_path,
loaded_model,
extensions_directory=None,
check_consistency=check_consistency,
device=device,
Expand Down