Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ outputs/
examples/basic_usage/*.xyz
extensions/

# lock files for model training
tests/resources/*.trainlock

# sphinx gallery
docs/src/generated_examples/
*execution_times*
Expand Down
Empty file added tests/__init__.py
Empty file.
21 changes: 0 additions & 21 deletions tests/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +0,0 @@
from pathlib import Path

from metatrain.utils.architectures import get_default_hypers


MODEL_HYPERS = get_default_hypers("soap_bpnn")["model"]

RESOURCES_PATH = Path(__file__).parents[1] / "resources"

DATASET_PATH_QM9 = RESOURCES_PATH / "qm9_reduced_100.xyz"
DATASET_PATH_ETHANOL = RESOURCES_PATH / "ethanol_reduced_100.xyz"
DATASET_PATH_CARBON = RESOURCES_PATH / "carbon_reduced_100.xyz"
DATASET_PATH_QM7X = RESOURCES_PATH / "qm7x_reduced_100.xyz"
DATASET_PATH_DOS = RESOURCES_PATH / "dos_100.xyz"
EVAL_OPTIONS_PATH = RESOURCES_PATH / "eval.yaml"
MODEL_PATH = RESOURCES_PATH / "model-32-bit.pt"
MODEL_PATH_64_BIT = RESOURCES_PATH / "model-64-bit.ckpt"
MODEL_PATH_PET = RESOURCES_PATH / "model-pet.ckpt"
OPTIONS_PATH = RESOURCES_PATH / "options.yaml"
OPTIONS_PET_PATH = RESOURCES_PATH / "options-pet.yaml"
OPTIONS_EXTRA_DATA_PATH = RESOURCES_PATH / "options-extra-data.yaml"
36 changes: 25 additions & 11 deletions tests/cli/test_eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
from metatrain.utils.data.writers import DiskDatasetWriter
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists

from . import EVAL_OPTIONS_PATH, MODEL_HYPERS, MODEL_PATH, RESOURCES_PATH
from ..conftest import EVAL_OPTIONS_PATH, MODEL_HYPERS, RESOURCES_PATH


@pytest.fixture
def model():
def model(MODEL_PATH):
return torch.jit.load(MODEL_PATH)


Expand All @@ -31,7 +31,7 @@ def options():
return OmegaConf.load(EVAL_OPTIONS_PATH)


def test_eval_cli(monkeypatch, tmp_path):
def test_eval_cli(monkeypatch, tmp_path, MODEL_PATH):
"""Test succesful run of the eval script via the CLI with default arguments"""
monkeypatch.chdir(tmp_path)
shutil.copy(RESOURCES_PATH / "qm9_reduced_100.xyz", "qm9_reduced_100.xyz")
Expand All @@ -54,15 +54,22 @@ def test_eval_cli(monkeypatch, tmp_path):
assert Path("output.xyz").is_file()


@pytest.mark.parametrize("model_name", ["model-32-bit.pt", "model-64-bit.pt"])
def test_eval(monkeypatch, tmp_path, caplog, model_name, options):
@pytest.mark.parametrize("model_type", ["32-bit", "64-bit"])
def test_eval(request, monkeypatch, tmp_path, caplog, model_type, options):
"""Test that eval via python API runs without an error raise."""
monkeypatch.chdir(tmp_path)
caplog.set_level(logging.INFO)

shutil.copy(RESOURCES_PATH / "qm9_reduced_100.xyz", "qm9_reduced_100.xyz")

model = torch.jit.load(RESOURCES_PATH / model_name)
fixture_name = {
"32-bit": "MODEL_PATH",
"64-bit": "MODEL_PATH_64_BIT",
}.get(model_type)

model_path = request.getfixturevalue(fixture_name)

model = torch.jit.load(model_path)

eval_model(
model=model,
Expand All @@ -84,15 +91,22 @@ def test_eval(monkeypatch, tmp_path, caplog, model_name, options):
frames[0].info["energy"]


@pytest.mark.parametrize("model_name", ["model-32-bit.pt", "model-64-bit.pt"])
def test_eval_batch_size(monkeypatch, tmp_path, caplog, model_name, options):
@pytest.mark.parametrize("model_type", ["32-bit", "64-bit"])
def test_eval_batch_size(request, monkeypatch, tmp_path, caplog, model_type, options):
"""Test that eval via python API runs without an error raise."""
monkeypatch.chdir(tmp_path)
caplog.set_level(logging.DEBUG)

shutil.copy(RESOURCES_PATH / "qm9_reduced_100.xyz", "qm9_reduced_100.xyz")

model = torch.jit.load(RESOURCES_PATH / model_name)
fixture_name = {
"32-bit": "MODEL_PATH",
"64-bit": "MODEL_PATH_64_BIT",
}.get(model_type)

model_path = request.getfixturevalue(fixture_name)

model = torch.jit.load(model_path)

eval_model(
model=model,
Expand Down Expand Up @@ -180,14 +194,14 @@ def test_eval_no_targets(monkeypatch, tmp_path, model, options):


@pytest.mark.parametrize("suffix", [".zip", ".mts"])
def test_eval_disk_dataset(monkeypatch, tmp_path, caplog, suffix):
def test_eval_disk_dataset(monkeypatch, tmp_path, caplog, suffix, MODEL_PATH):
"""Test that eval via python API runs without an error raise."""
monkeypatch.chdir(tmp_path)
caplog.set_level(logging.INFO)

shutil.copy(RESOURCES_PATH / "qm9_reduced_100.xyz", "qm9_reduced_100.xyz")

model = torch.jit.load(RESOURCES_PATH / "model-32-bit.pt")
model = torch.jit.load(MODEL_PATH)

options = OmegaConf.create(
{
Expand Down
40 changes: 23 additions & 17 deletions tests/cli/test_export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,14 @@
from metatrain.utils.architectures import find_all_architectures
from metatrain.utils.io import load_model

from . import RESOURCES_PATH


@pytest.mark.parametrize("path", [Path("exported.pt"), "exported.pt"])
def test_export(monkeypatch, tmp_path, path, caplog):
def test_export(monkeypatch, tmp_path, path, caplog, MODEL_PATH_64_BIT):
"""Tests the export_model function."""
monkeypatch.chdir(tmp_path)
caplog.set_level(logging.INFO)

checkpoint_path = RESOURCES_PATH / "model-64-bit.ckpt"
checkpoint_path = MODEL_PATH_64_BIT.with_suffix(".ckpt")
export_model(checkpoint_path, path)

# Test if extensions are saved
Expand All @@ -43,20 +41,28 @@ def test_export(monkeypatch, tmp_path, path, caplog):

@pytest.mark.parametrize("output", [None, "exported.pt"])
@pytest.mark.parametrize("model_type", ["32-bit", "64-bit", "pet"])
def test_export_cli(monkeypatch, tmp_path, output, model_type):
def test_export_cli(request, monkeypatch, tmp_path, output, model_type):
"""Test that the export cli runs without an error raise."""
monkeypatch.chdir(tmp_path)

fixture_name = {
"32-bit": "MODEL_PATH",
"64-bit": "MODEL_PATH_64_BIT",
"pet": "MODEL_PATH_PET",
}.get(model_type)

model_path = request.getfixturevalue(fixture_name)

command = [
"mtt",
"export",
str(RESOURCES_PATH / f"model-{model_type}.ckpt"),
str(model_path.with_suffix(".ckpt")),
]

if output is not None:
command += ["-o", output]
else:
output = f"model-{model_type}.pt"
output = model_path.name

subprocess.check_call(command)
assert Path(output).is_file()
Expand All @@ -83,21 +89,21 @@ def test_export_cli(monkeypatch, tmp_path, output, model_type):
assert next(model.parameters()).device.type == "cpu"


def test_export_with_env(monkeypatch, tmp_path):
def test_export_with_env(monkeypatch, tmp_path, MODEL_PATH):
"""Test that export with env variable works for local file."""
monkeypatch.chdir(tmp_path)

command = [
"mtt",
"export",
str(RESOURCES_PATH / "model-32-bit.ckpt"),
str(MODEL_PATH.with_suffix(".ckpt")),
]

env = os.environ.copy()
env["HF_TOKEN"] = "1234"

subprocess.check_call(command, env=env)
assert Path("model-32-bit.pt").is_file()
assert Path(MODEL_PATH.name).is_file()


def test_export_cli_unknown_architecture(tmpdir):
Expand All @@ -113,11 +119,11 @@ def test_export_cli_unknown_architecture(tmpdir):
assert architecture_name in stdout


def test_reexport(monkeypatch, tmp_path):
def test_reexport(monkeypatch, tmp_path, MODEL_PATH_64_BIT):
"""Test that an already exported model can be loaded and again exported."""
monkeypatch.chdir(tmp_path)

checkpoint_path = RESOURCES_PATH / "model-64-bit.ckpt"
checkpoint_path = MODEL_PATH_64_BIT.with_suffix(".ckpt")
export_model(checkpoint_path, "exported.pt")
export_model("exported.pt", "exported_new.pt")

Expand Down Expand Up @@ -200,7 +206,7 @@ def test_token_env_error():
subprocess.check_call(command, env=env)


def test_metadata(monkeypatch, tmp_path):
def test_metadata(monkeypatch, tmp_path, MODEL_PATH):
"""Test that the export cli does inject metadata."""
monkeypatch.chdir(tmp_path)

Expand All @@ -211,17 +217,17 @@ def test_metadata(monkeypatch, tmp_path):
command = [
"mtt",
"export",
str(RESOURCES_PATH / "model-32-bit.ckpt"),
str(MODEL_PATH.with_suffix(".ckpt")),
"--metadata=metadata.yaml",
]

subprocess.check_call(command)
model = load_model("model-32-bit.pt", extensions_directory="extensions/")
model = load_model(MODEL_PATH.name, extensions_directory="extensions/")

assert f"This is the {model_name} model" in str(model.metadata())


def test_export_checkpoint_with_metadata(monkeypatch, tmp_path):
def test_export_checkpoint_with_metadata(monkeypatch, tmp_path, MODEL_PATH):
"""Tests that the metadata is correctly assigned to the exported
model if the checkpoint has the metadata inside."""

Expand All @@ -234,7 +240,7 @@ def test_export_checkpoint_with_metadata(monkeypatch, tmp_path):
command = [
"mtt",
"export",
str(RESOURCES_PATH / "model-32-bit.ckpt"),
str(MODEL_PATH.with_suffix(".ckpt")),
"-o=model-32-bit-with-metadata.ckpt",
"--metadata=metadata.yaml",
]
Expand Down
Loading
Loading