From 1e406a8906f4b3f0863008cfb3f5b2ca4f97f542 Mon Sep 17 00:00:00 2001 From: Petr Date: Thu, 29 Jan 2026 17:48:10 +0100 Subject: [PATCH 1/7] added first version of calc_eos script --- .../calc_equation_of_state.py | 105 ++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 ml_peg/calcs/bulk_crystal/equation_of_state/calc_equation_of_state.py diff --git a/ml_peg/calcs/bulk_crystal/equation_of_state/calc_equation_of_state.py b/ml_peg/calcs/bulk_crystal/equation_of_state/calc_equation_of_state.py new file mode 100644 index 000000000..d8429e29f --- /dev/null +++ b/ml_peg/calcs/bulk_crystal/equation_of_state/calc_equation_of_state.py @@ -0,0 +1,105 @@ +"""Run calculations for EOS tests.""" + +from __future__ import annotations +from copy import copy +from pathlib import Path +from typing import Any +from datetime import datetime + +from ase import units +from ase.io import read, write + +import pandas as pd +import numpy as np +import pytest + +from ml_peg.calcs.utils.utils import download_s3_data +from ml_peg.models.get_models import load_models +from ml_peg.models.models import current_models + + +from ase.lattice.cubic import SimpleCubicFactory, \ + FaceCenteredCubic, BodyCenteredCubic + +MODELS = load_models(current_models) + +DATA_PATH = Path(__file__).parent / "data" +OUT_PATH = Path(__file__).parent / "outputs" + + + +class A15Factory(SimpleCubicFactory): + "A factory for creating A15 lattices." + xtal_name = "A15" + bravais_basis = [[0, 0, 0], + [0.5, 0.5, 0.5], + [0.5, 0.25, 0.0], + [0.5, 0.75, 0.0], + [0.0, 0.5, 0.25], + [0.0, 0.5, 0.75], + [0.25, 0.0, 0.5], + [0.75, 0.0, 0.5]] + + +A15 = A15Factory() + +lattices = {"BCC": BodyCenteredCubic, + "FCC": FaceCenteredCubic, + "A15": A15} + + +def equation_of_state(calc, lattice, symbol="W", size = (2, 2, 2), + volumes_per_atoms=np.linspace(12, 22, 10, endpoint=False)): + """Compute the equation of state for a given element and lattice. + """ + + # dummy call to have calc_num_atoms available + lattice(symbol="W", latticeconstant=3.16) + lattice_constants = (volumes_per_atoms * lattice.calc_num_atoms()) ** (1 / 3) + + structures = [lattice(latticeconstant=lc, size=size, symbol=symbol) for lc in lattice_constants] + for structure in structures: + structure.calc = calc + + energies = [structure.get_potential_energy() / len(structure) for structure in structures] + + return np.array(lattice_constants), np.array(energies) + + +@pytest.mark.parametrize("mlip", MODELS.items()) +def test_equation_of_state(mlip: tuple[str, Any]) -> None: + """Test equation of state calculation. + For the moment only for three BCC metals""" + + model_name, model = mlip + calc = model.get_calculator() + + volumes_per_atoms = np.linspace(12, 22, 100, endpoint=False) + results = {"V/atom": volumes_per_atoms} + + elements = ["W", "Mo", "Nb"] + + for element in elements: + for lattice_name, lattice in lattices.items(): + start_time = datetime.now() + print(f"Start time for {lattice_name} @ {model_name}: {start_time}") + lattice_constants, energies = equation_of_state( + calc, lattice, symbol=element, volumes_per_atoms=volumes_per_atoms + ) + end_time = datetime.now() + duration = end_time - start_time + hours, remainder = divmod(duration.seconds, 3600) + minutes, seconds = divmod(remainder, 60) + print(f"End time for {lattice_name} @ {model_name}: {end_time}") + print(f"Duration for {lattice_name} @ {model_name}: {hours} hours {minutes} minutes {seconds} seconds") + print(duration) + + + results[f"{element}_{lattice_name}_a"] = lattice_constants + results[f"{element}_{lattice_name}_E"] = energies + + write_dir = OUT_PATH / model_name + df = pd.DataFrame(results) + output_file = write_dir / "eos_results.csv" + write_dir.mkdir(parents=True, exist_ok=True) + df.to_csv(output_file, index=False) \ No newline at end of file From 3254cf495ec180b0a2cca20e14801e775201b080 Mon Sep 17 00:00:00 2001 From: Petr Date: Fri, 30 Jan 2026 11:22:45 +0100 Subject: [PATCH 2/7] split results to files per element --- .../equation_of_state/calc_equation_of_state.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/ml_peg/calcs/bulk_crystal/equation_of_state/calc_equation_of_state.py b/ml_peg/calcs/bulk_crystal/equation_of_state/calc_equation_of_state.py index d8429e29f..4daa9c66f 100644 --- a/ml_peg/calcs/bulk_crystal/equation_of_state/calc_equation_of_state.py +++ b/ml_peg/calcs/bulk_crystal/equation_of_state/calc_equation_of_state.py @@ -75,17 +75,19 @@ def test_equation_of_state(mlip: tuple[str, Any]) -> None: calc = model.get_calculator() volumes_per_atoms = np.linspace(12, 22, 100, endpoint=False) - results = {"V/atom": volumes_per_atoms} - + elements = ["W", "Mo", "Nb"] for element in elements: + results = {"V/atom": volumes_per_atoms} for lattice_name, lattice in lattices.items(): start_time = datetime.now() print(f"Start time for {lattice_name} @ {model_name}: {start_time}") + lattice_constants, energies = equation_of_state( calc, lattice, symbol=element, volumes_per_atoms=volumes_per_atoms ) + end_time = datetime.now() duration = end_time - start_time hours, remainder = divmod(duration.seconds, 3600) @@ -98,8 +100,8 @@ def test_equation_of_state(mlip: tuple[str, Any]) -> None: results[f"{element}_{lattice_name}_a"] = lattice_constants results[f"{element}_{lattice_name}_E"] = energies - write_dir = OUT_PATH / model_name - df = pd.DataFrame(results) - output_file = write_dir / "eos_results.csv" - write_dir.mkdir(parents=True, exist_ok=True) - df.to_csv(output_file, index=False) \ No newline at end of file + write_dir = OUT_PATH / model_name + df = pd.DataFrame(results) + output_file = write_dir / f"{element}_eos_results.csv" + write_dir.mkdir(parents=True, exist_ok=True) + df.to_csv(output_file, index=False) \ No newline at end of file From dd58ee3641773ddf9114e4e57e89121b49336267 Mon Sep 17 00:00:00 2001 From: Petr Date: Fri, 30 Jan 2026 13:38:00 +0100 Subject: [PATCH 3/7] added reference data --- .../calc_equation_of_state.py | 35 ++++++++++++------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/ml_peg/calcs/bulk_crystal/equation_of_state/calc_equation_of_state.py b/ml_peg/calcs/bulk_crystal/equation_of_state/calc_equation_of_state.py index 4daa9c66f..6431482c5 100644 --- a/ml_peg/calcs/bulk_crystal/equation_of_state/calc_equation_of_state.py +++ b/ml_peg/calcs/bulk_crystal/equation_of_state/calc_equation_of_state.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import Any from datetime import datetime +from glob import glob from ase import units from ase.io import read, write @@ -23,11 +24,10 @@ MODELS = load_models(current_models) -DATA_PATH = Path(__file__).parent / "data" +DATA_PATH = Path(__file__).parent / "../../../../inputs/bulk_crystal/equation_of_state/" OUT_PATH = Path(__file__).parent / "outputs" - class A15Factory(SimpleCubicFactory): "A factory for creating A15 lattices." xtal_name = "A15" @@ -74,15 +74,26 @@ def test_equation_of_state(mlip: tuple[str, Any]) -> None: model_name, model = mlip calc = model.get_calculator() - volumes_per_atoms = np.linspace(12, 22, 100, endpoint=False) - - elements = ["W", "Mo", "Nb"] + fns = list(DATA_PATH.glob("*DFT*")) + + + for fn in fns: + element = fn.name.split("_")[0] + print(f"Starting EOS calculations for {element} with model {model_name}") - for element in elements: + dft_data = pd.read_csv(fn, comment="#") + + volumes_per_atoms = np.linspace(np.round(dft_data[dft_data.columns[0]].min() * 0.95), + np.round(dft_data[dft_data.columns[0]].max() * 1.05), 50, endpoint=False) results = {"V/atom": volumes_per_atoms} - for lattice_name, lattice in lattices.items(): + + phases = [col.split("_")[1] for col in dft_data.columns if "Delta" in col] + + for phase in phases: + assert phase in lattices, f"Lattice {phase} not implemented for EOS test." + lattice = lattices[phase] start_time = datetime.now() - print(f"Start time for {lattice_name} @ {model_name}: {start_time}") + print(f"Start time for {phase} @ {model_name}: {start_time}") lattice_constants, energies = equation_of_state( calc, lattice, symbol=element, volumes_per_atoms=volumes_per_atoms @@ -92,13 +103,13 @@ def test_equation_of_state(mlip: tuple[str, Any]) -> None: duration = end_time - start_time hours, remainder = divmod(duration.seconds, 3600) minutes, seconds = divmod(remainder, 60) - print(f"End time for {lattice_name} @ {model_name}: {end_time}") - print(f"Duration for {lattice_name} @ {model_name}: {hours} hours {minutes} minutes {seconds} seconds") + print(f"End time for {phase} @ {model_name}: {end_time}") + print(f"Duration for {phase} @ {model_name}: {hours} hours {minutes} minutes {seconds} seconds") print(duration) - results[f"{element}_{lattice_name}_a"] = lattice_constants - results[f"{element}_{lattice_name}_E"] = energies + results[f"{phase}_a"] = lattice_constants + results[f"{phase}_E"] = energies write_dir = OUT_PATH / model_name df = pd.DataFrame(results) From 7de9e7a28e8afb6355a5210ee5d40b1d32d238c5 Mon Sep 17 00:00:00 2001 From: Petr Grigorev Date: Thu, 26 Mar 2026 17:51:49 +0100 Subject: [PATCH 4/7] added first version of metrics, analysis and app --- .../analyse_equation_of_state.py | 439 ++++++++++++++++++ .../equation_of_state/metrics.yml | 19 + .../app_equation_of_state.py | 59 +++ .../calc_equation_of_state.py | 126 +++-- 4 files changed, 597 insertions(+), 46 deletions(-) create mode 100644 ml_peg/analysis/bulk_crystal/equation_of_state/analyse_equation_of_state.py create mode 100644 ml_peg/analysis/bulk_crystal/equation_of_state/metrics.yml create mode 100644 ml_peg/app/bulk_crystal/equation_of_state/app_equation_of_state.py diff --git a/ml_peg/analysis/bulk_crystal/equation_of_state/analyse_equation_of_state.py b/ml_peg/analysis/bulk_crystal/equation_of_state/analyse_equation_of_state.py new file mode 100644 index 000000000..8d0dd7573 --- /dev/null +++ b/ml_peg/analysis/bulk_crystal/equation_of_state/analyse_equation_of_state.py @@ -0,0 +1,439 @@ +"""Analyse equation of state benchmark.""" + +from __future__ import annotations + +from pathlib import Path + +from ase.eos import EquationOfState, birchmurnaghan +import numpy as np +import pandas as pd +import pytest + +from ml_peg.analysis.utils.decorators import build_table +from ml_peg.analysis.utils.utils import load_metrics_config +from ml_peg.app import APP_ROOT +from ml_peg.calcs import CALCS_ROOT +from ml_peg.models.get_models import get_model_names +from ml_peg.models.models import current_models + +MODELS = get_model_names(current_models) +CALC_PATH = CALCS_ROOT / "bulk_crystal" / "equation_of_state" / "outputs" +OUT_PATH = APP_ROOT / "data" / "bulk_crystal" / "equation_of_state" + +DATA_PATH = Path(__file__).parent / "../../../../inputs/bulk_crystal/equation_of_state/" + +METRICS_CONFIG_PATH = Path(__file__).with_name("metrics.yml") +DEFAULT_THRESHOLDS, DEFAULT_TOOLTIPS, DEFAULT_WEIGHTS = load_metrics_config( + METRICS_CONFIG_PATH +) + +ELEMENTS = [f.name.split("_")[0] for f in DATA_PATH.glob("*DFT*")] + + +def _fit_bm_clean(volumes: np.ndarray, energies: np.ndarray) -> tuple | None: + """ + Fit a Birch-Murnaghan EOS, ignoring non-finite points. + + Parameters + ---------- + volumes + Per-atom volumes in ų. + energies + Per-atom energies in eV. + + Returns + ------- + tuple or None + ``eos.eos_parameters`` = ``(E0, B0, BP, V0)`` on success, or ``None`` + if fewer than 4 finite points exist or fitting fails. + """ + mask = np.isfinite(volumes) & np.isfinite(energies) + v = np.asarray(volumes)[mask] + e = np.asarray(energies)[mask] + if v.size < 4: + return None + try: + eos = EquationOfState(v, e, eos="birchmurnaghan") + eos.fit() + return eos.eos_parameters + except Exception: + return None + + +def calc_delta( + data_f: dict[str, float], + data_w: dict[str, float], + useasymm: bool, + vi: float, + vf: float, +) -> tuple[float, float, float]: + """ + Calculate the Delta metric (meV/atom) between two EOS curves. + + Expects B0 in ASE-native units (eV/A^3), as returned directly by + ``EquationOfState.eos_parameters``. No GPa conversion is applied. + + Adapted from ``calcDelta.py``, supplementary material from: + K. Lejaeghere, V. Van Speybroeck, G. Van Oost, and S. Cottenier: + "Error estimates for solid-state density-functional theory predictions: + an overview by means of the ground-state elemental crystals" + Crit. Rev. Solid State (2014). Open access: http://arxiv.org/abs/1204.2733 + + Parameters + ---------- + data_f + Reference EOS parameters: ``{"V0": float, "B0": float, "BP": float}``. + B0 in eV/A^3 (ASE units). + data_w + Model EOS parameters with the same keys. + useasymm + If ``True``, normalise Delta1 by the reference V0/B0 only; if + ``False`` use the average of reference and model values. + vi + Lower volume integration bound (A^3/atom). + vf + Upper volume integration bound (A^3/atom). + + Returns + ------- + tuple[float, float, float] + ``(delta, deltarel, delta1)`` where ``delta`` is in meV/atom. + """ + v0w = data_w["V0"] + b0w = data_w["B0"] + b1w = data_w["BP"] + + v0f = data_f["V0"] + b0f = data_f["B0"] + b1f = data_f["BP"] + + vref = 30.0 + bref = 100.0 # eV/A^3, matching ASE-native units + + a3f = 9.0 * v0f**3.0 * b0f / 16.0 * (b1f - 4.0) + a2f = 9.0 * v0f ** (7.0 / 3.0) * b0f / 16.0 * (14.0 - 3.0 * b1f) + a1f = 9.0 * v0f ** (5.0 / 3.0) * b0f / 16.0 * (3.0 * b1f - 16.0) + a0f = 9.0 * v0f * b0f / 16.0 * (6.0 - b1f) + + a3w = 9.0 * v0w**3.0 * b0w / 16.0 * (b1w - 4.0) + a2w = 9.0 * v0w ** (7.0 / 3.0) * b0w / 16.0 * (14.0 - 3.0 * b1w) + a1w = 9.0 * v0w ** (5.0 / 3.0) * b0w / 16.0 * (3.0 * b1w - 16.0) + a0w = 9.0 * v0w * b0w / 16.0 * (6.0 - b1w) + + x = [0.0] * 7 + x[0] = (a0f - a0w) ** 2 + x[1] = 6.0 * (a1f - a1w) * (a0f - a0w) + x[2] = -3.0 * (2.0 * (a2f - a2w) * (a0f - a0w) + (a1f - a1w) ** 2.0) + x[3] = -2.0 * (a3f - a3w) * (a0f - a0w) - 2.0 * (a2f - a2w) * (a1f - a1w) + x[4] = -3.0 / 5.0 * (2.0 * (a3f - a3w) * (a1f - a1w) + (a2f - a2w) ** 2.0) + x[5] = -6.0 / 7.0 * (a3f - a3w) * (a2f - a2w) + x[6] = -1.0 / 3.0 * (a3f - a3w) ** 2.0 + + y = [0.0] * 7 + y[0] = (a0f + a0w) ** 2 / 4.0 + y[1] = 3.0 * (a1f + a1w) * (a0f + a0w) / 2.0 + y[2] = -3.0 * (2.0 * (a2f + a2w) * (a0f + a0w) + (a1f + a1w) ** 2.0) / 4.0 + y[3] = -(a3f + a3w) * (a0f + a0w) / 2.0 - (a2f + a2w) * (a1f + a1w) / 2.0 + y[4] = -3.0 / 20.0 * (2.0 * (a3f + a3w) * (a1f + a1w) + (a2f + a2w) ** 2.0) + y[5] = -3.0 / 14.0 * (a3f + a3w) * (a2f + a2w) + y[6] = -1.0 / 12.0 * (a3f + a3w) ** 2.0 + + fi = 0.0 + ff = 0.0 + gi = 0.0 + gf = 0.0 + for n in range(7): + fi += x[n] * vi ** (-(2.0 * n - 3.0) / 3.0) + ff += x[n] * vf ** (-(2.0 * n - 3.0) / 3.0) + gi += y[n] * vi ** (-(2.0 * n - 3.0) / 3.0) + gf += y[n] * vf ** (-(2.0 * n - 3.0) / 3.0) + + delta = 1000.0 * np.sqrt((ff - fi) / (vf - vi)) + deltarel = 100.0 * np.sqrt((ff - fi) / (gf - gi)) + if useasymm: + delta1 = delta / v0w / b0w * vref * bref + else: + delta1 = delta / (v0w + v0f) / (b0w + b0f) * 4.0 * vref * bref + + return delta, deltarel, delta1 + + +def _phase_metrics_from_eos_mev( + dft_data: pd.DataFrame, model_data: pd.DataFrame +) -> tuple[float, float]: + """ + Compute phase-stability metrics relative to the reference phase. + + Parameters + ---------- + dft_data + DFT reference DataFrame with columns ``V/atom_{phase}`` and + ``Delta_{phase}_E`` for each phase. + model_data + Model results DataFrame with columns ``V/atom`` and ``{phase}_E`` + for each phase. + + Returns + ------- + tuple[float, float] + ``(PhaseDiffEOS_MAE_meV, CorrectStability_pct)``. Returns + ``(nan, nan)`` if any BM fit fails. + """ + phases = [ + col.split("_")[1] + for col in dft_data.columns + if col.startswith("Delta_") and col.endswith("_E") + ] + + dft_fits = [ + _fit_bm_clean( + dft_data[f"V/atom_{phase}"].values, + dft_data[f"Delta_{phase}_E"].values, + ) + for phase in phases + ] + model_fits = [ + _fit_bm_clean(model_data["V/atom"].values, model_data[f"{phase}_E"].values) + for phase in phases + ] + + if any(fit is None for fit in dft_fits + model_fits): + return np.nan, np.nan + + ref_volumes = dft_data[f"V/atom_{phases[0]}"].values + ref_volumes = ref_volumes[np.isfinite(ref_volumes)] + v_grid = np.linspace(ref_volumes.min(), ref_volumes.max(), 80) + + dft_deltas = np.vstack( + [ + birchmurnaghan(v_grid, *dft_fit) - birchmurnaghan(v_grid, *dft_fits[0]) + for dft_fit in dft_fits[1:] + ] + ) + model_deltas = np.vstack( + [ + birchmurnaghan(v_grid, *model_fit) - birchmurnaghan(v_grid, *model_fits[0]) + for model_fit in model_fits[1:] + ] + ) + + mae_ev = float(np.mean(np.abs(dft_deltas - model_deltas))) + correct_stability_pct = 100.0 * float(np.mean(np.all(model_deltas > 0, axis=0))) + + return 1000.0 * mae_ev, correct_stability_pct + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def eos_stats() -> dict[tuple[str, str], dict[str, float]]: + """ + Compute all three metrics for every model-element pair. + + Returns + ------- + dict[tuple[str, str], dict[str, float]] + Mapping of ``(model_name, element)`` to ``{"Delta", + "PhaseDiffEOS_MAE_meV", "CorrectStability_pct"}``. + """ + OUT_PATH.mkdir(parents=True, exist_ok=True) + results: dict[tuple[str, str], dict[str, float]] = {} + + for model_name in MODELS: + model_dir = CALC_PATH / model_name + if not model_dir.exists(): + continue + + for element in ELEMENTS: + model_csv = model_dir / f"{element}_eos_results.csv" + dft_csv = DATA_PATH / f"{element}_eos_DFT.csv" + if not model_csv.exists() or not dft_csv.exists(): + continue + + model_data = pd.read_csv(model_csv) + dft_data = pd.read_csv(dft_csv, comment="#") + + phases = [ + col.split("_")[1] + for col in dft_data.columns + if col.startswith("Delta_") and col.endswith("_E") + ] + # BM fit for reference phase (Delta metric) + ref_phase = phases[0] # Assuming the first phase is the reference + ref_params = _fit_bm_clean( + model_data["V/atom"].values, model_data[f"{ref_phase}_E"].values + ) + dft_ref_params = _fit_bm_clean( + dft_data[f"V/atom_{ref_phase}"].values, + dft_data[f"Delta_{ref_phase}_E"].values, + ) + + if ref_params is not None and dft_ref_params is not None: + # eos_parameters: (E0, B0, BP, V0) + model_bm = { + "V0": ref_params[-1], + "B0": ref_params[-3], + "BP": ref_params[-2], + } + dft_bm = { + "V0": dft_ref_params[-1], + "B0": dft_ref_params[-3], + "BP": dft_ref_params[-2], + } + volumes = model_data["V/atom"] + delta, _, _ = calc_delta( + dft_bm, + model_bm, + useasymm=False, + vi=float(volumes.iloc[0]), + vf=float(volumes.iloc[-1]), + ) + else: + delta = np.nan + + phase_diff_mae, correct_stability = _phase_metrics_from_eos_mev( + dft_data, model_data + ) + + results[(model_name, element)] = { + "Δ": delta, + "Phase energy": phase_diff_mae, + "Phase stability": correct_stability, + } + + return results + + +@pytest.fixture +def delta( + eos_stats: dict[tuple[str, str], dict[str, float]], +) -> dict[str, float]: + """ + Mean Delta (meV/atom) across elements for each model. + + Parameters + ---------- + eos_stats + Per-(model, element) metric values. + + Returns + ------- + dict[str, float] + Mean Delta per model. + """ + results: dict[str, float] = {} + for model_name in MODELS: + values = [ + eos_stats[(model_name, el)]["Δ"] + for el in ELEMENTS + if (model_name, el) in eos_stats + ] + results[model_name] = float(np.nanmean(values)) if values else None + return results + + +@pytest.fixture +def phase_diff_eos_mae( + eos_stats: dict[tuple[str, str], dict[str, float]], +) -> dict[str, float]: + """ + Mean PhaseDiffEOS MAE (meV/atom) across elements for each model. + + Parameters + ---------- + eos_stats + Per-(model, element) metric values. + + Returns + ------- + dict[str, float] + Mean PhaseDiffEOS_MAE_meV per model. + """ + results: dict[str, float] = {} + for model_name in MODELS: + values = [ + eos_stats[(model_name, el)]["Phase energy"] + for el in ELEMENTS + if (model_name, el) in eos_stats + ] + results[model_name] = float(np.nanmean(values)) if values else None + return results + + +@pytest.fixture +def correct_stability( + eos_stats: dict[tuple[str, str], dict[str, float]], +) -> dict[str, float]: + """ + Mean CorrectStability (%) across elements for each model. + + Parameters + ---------- + eos_stats + Per-(model, element) metric values. + + Returns + ------- + dict[str, float] + Mean Phase stability per model. + """ + results: dict[str, float] = {} + for model_name in MODELS: + values = [ + eos_stats[(model_name, el)]["Phase stability"] + for el in ELEMENTS + if (model_name, el) in eos_stats + ] + results[model_name] = float(np.nanmean(values)) if values else None + return results + + +@pytest.fixture +@build_table( + filename=OUT_PATH / "eos_metrics_table.json", + metric_tooltips=DEFAULT_TOOLTIPS, + thresholds=DEFAULT_THRESHOLDS, + weights=DEFAULT_WEIGHTS, +) +def metrics( + delta: dict[str, float], + phase_diff_eos_mae: dict[str, float], + correct_stability: dict[str, float], +) -> dict[str, dict]: + """ + All EOS benchmark metrics. + + Parameters + ---------- + delta + Mean Delta per model (meV/atom). + phase_diff_eos_mae + Mean PhaseDiffEOS MAE per model (meV/atom). + correct_stability + Mean Phase stability per model (%). + + Returns + ------- + dict[str, dict] + Mapping of metric name to per-model value dicts. + """ + return { + "Δ": delta, + "Phase energy": phase_diff_eos_mae, + "Phase stability": correct_stability, + } + + +def test_equation_of_state(metrics: dict[str, dict]) -> None: + """ + Run EOS benchmark analysis. + + Parameters + ---------- + metrics + All EOS benchmark metric values. + """ + return diff --git a/ml_peg/analysis/bulk_crystal/equation_of_state/metrics.yml b/ml_peg/analysis/bulk_crystal/equation_of_state/metrics.yml new file mode 100644 index 000000000..5f22e6139 --- /dev/null +++ b/ml_peg/analysis/bulk_crystal/equation_of_state/metrics.yml @@ -0,0 +1,19 @@ +metrics: + Δ: + good: 0.0 + bad: 200.0 + unit: meV/atom + tooltip: RMS energy difference per atom between the two EOS curves averaged over the sampled volume range. + level_of_theory: PBE + Phase energy: + good: 0.0 + bad: 200.0 + unit: meV/atom + tooltip: Mean absolute error of the phase energy differences relative to the reference phase between model and DFT. + level_of_theory: PBE + Phase stability: + good: 100.0 + bad: 0.0 + unit: '%' + tooltip: Percentage of correctly predicted stable phases. + level_of_theory: PBE diff --git a/ml_peg/app/bulk_crystal/equation_of_state/app_equation_of_state.py b/ml_peg/app/bulk_crystal/equation_of_state/app_equation_of_state.py new file mode 100644 index 000000000..2e2cb2d7c --- /dev/null +++ b/ml_peg/app/bulk_crystal/equation_of_state/app_equation_of_state.py @@ -0,0 +1,59 @@ +"""Run equation of state benchmark app.""" + +from __future__ import annotations + +from dash import Dash +from dash.html import Div + +from ml_peg.app import APP_ROOT +from ml_peg.app.base_app import BaseApp +from ml_peg.models.get_models import get_model_names +from ml_peg.models.models import current_models + +# Get all models +MODELS = get_model_names(current_models) +BENCHMARK_NAME = "Equation of State" +DOCS_URL = "https://ddmms.github.io/ml-peg/user_guide/benchmarks/bulk_crystal.html#equation-of-state" +DATA_PATH = APP_ROOT / "data" / "bulk_crystal" / "equation_of_state" + + +class EquationOfStateApp(BaseApp): + """Equation of State benchmark app layout and callbacks.""" + + def register_callbacks(self) -> None: + """Register callbacks to app.""" + # Build plots for models with data (read_density_plot_for_model + # returns None for models without data) + return + + +def get_app() -> EquationOfStateApp: + """ + Get equation of state benchmark app layout and callback registration. + + Returns + ------- + EquationOfStateApp + Benchmark layout and callback registration. + """ + return EquationOfStateApp( + name=BENCHMARK_NAME, + description=( + "Performance when calculating the equation of state for different " + "bulk crystal (W, Mo, Nb) structures " + "scomapred to PBE data from literature." + ), + docs_url=DOCS_URL, + table_path=DATA_PATH / "eos_metrics_table.json", + extra_components=[ + Div(id=f"{BENCHMARK_NAME}-figure-placeholder"), + ], + ) + + +if __name__ == "__main__": + full_app = Dash(__name__, assets_folder=DATA_PATH.parent.parent) + equation_of_state_app = get_app() + full_app.layout = equation_of_state_app.layout + equation_of_state_app.register_callbacks() + full_app.run(port=8054, debug=True) diff --git a/ml_peg/calcs/bulk_crystal/equation_of_state/calc_equation_of_state.py b/ml_peg/calcs/bulk_crystal/equation_of_state/calc_equation_of_state.py index 6431482c5..ccf7b5c11 100644 --- a/ml_peg/calcs/bulk_crystal/equation_of_state/calc_equation_of_state.py +++ b/ml_peg/calcs/bulk_crystal/equation_of_state/calc_equation_of_state.py @@ -1,27 +1,18 @@ """Run calculations for EOS tests.""" from __future__ import annotations -from copy import copy + from pathlib import Path from typing import Any -from datetime import datetime -from glob import glob - -from ase import units -from ase.io import read, write -import pandas as pd +from ase.lattice.cubic import BodyCenteredCubic, FaceCenteredCubic, SimpleCubicFactory import numpy as np +import pandas as pd import pytest -from ml_peg.calcs.utils.utils import download_s3_data from ml_peg.models.get_models import load_models from ml_peg.models.models import current_models - -from ase.lattice.cubic import SimpleCubicFactory, \ - FaceCenteredCubic, BodyCenteredCubic - MODELS = load_models(current_models) DATA_PATH = Path(__file__).parent / "../../../../inputs/bulk_crystal/equation_of_state/" @@ -29,85 +20,128 @@ class A15Factory(SimpleCubicFactory): - "A factory for creating A15 lattices." + """A factory for creating A15 lattices.""" + xtal_name = "A15" - bravais_basis = [[0, 0, 0], - [0.5, 0.5, 0.5], - [0.5, 0.25, 0.0], - [0.5, 0.75, 0.0], - [0.0, 0.5, 0.25], - [0.0, 0.5, 0.75], - [0.25, 0.0, 0.5], - [0.75, 0.0, 0.5]] + bravais_basis = [ + [0, 0, 0], + [0.5, 0.5, 0.5], + [0.5, 0.25, 0.0], + [0.5, 0.75, 0.0], + [0.0, 0.5, 0.25], + [0.0, 0.5, 0.75], + [0.25, 0.0, 0.5], + [0.75, 0.0, 0.5], + ] A15 = A15Factory() -lattices = {"BCC": BodyCenteredCubic, - "FCC": FaceCenteredCubic, - "A15": A15} +lattices = {"BCC": BodyCenteredCubic, "FCC": FaceCenteredCubic, "A15": A15} -def equation_of_state(calc, lattice, symbol="W", size = (2, 2, 2), - volumes_per_atoms=np.linspace(12, 22, 10, endpoint=False)): - """Compute the equation of state for a given element and lattice. +def equation_of_state( + calc, + lattice, + volumes_per_atoms, + symbol="W", + size=(2, 2, 2), +): + """ + Compute the equation of state for a given element and lattice. + + Parameters + ---------- + calc + ASE calculator to use for energy calculations. + lattice + ASE lattice class to use for structure generation. + volumes_per_atoms + Array of volumes per atom to sample for the EOS curve. + symbol + Chemical symbol of the element to use for structure generation. + size + Size of the supercell to generate for each volume per atom. + + Returns + ------- + tuple[np.ndarray, np.ndarray] + Lattice constants (A) and energies (eV/atom) arrays. """ - # dummy call to have calc_num_atoms available - lattice(symbol="W", latticeconstant=3.16) + lattice(symbol="W", latticeconstant=3.16) lattice_constants = (volumes_per_atoms * lattice.calc_num_atoms()) ** (1 / 3) - structures = [lattice(latticeconstant=lc, size=size, symbol=symbol) for lc in lattice_constants] + structures = [ + lattice(latticeconstant=lc, size=size, symbol=symbol) + for lc in lattice_constants + ] for structure in structures: structure.calc = calc - energies = [structure.get_potential_energy() / len(structure) for structure in structures] + energies = [ + structure.get_potential_energy() / len(structure) for structure in structures + ] return np.array(lattice_constants), np.array(energies) @pytest.mark.parametrize("mlip", MODELS.items()) def test_equation_of_state(mlip: tuple[str, Any]) -> None: - """Test equation of state calculation. - For the moment only for three BCC metals""" + """ + Test equation of state calculation for three BCC metals. + Parameters + ---------- + mlip + Tuple of (model_name, model) as provided by pytest parametrize. + """ model_name, model = mlip calc = model.get_calculator() fns = list(DATA_PATH.glob("*DFT*")) - - + for fn in fns: element = fn.name.split("_")[0] print(f"Starting EOS calculations for {element} with model {model_name}") - dft_data = pd.read_csv(fn, comment="#") + dft_data = pd.read_csv(fn, comment="#") - volumes_per_atoms = np.linspace(np.round(dft_data[dft_data.columns[0]].min() * 0.95), - np.round(dft_data[dft_data.columns[0]].max() * 1.05), 50, endpoint=False) + volumes_per_atoms = np.linspace( + np.round(dft_data[dft_data.columns[0]].min() * 0.95), + np.round(dft_data[dft_data.columns[0]].max() * 1.05), + 50, + endpoint=False, + ) results = {"V/atom": volumes_per_atoms} - + phases = [col.split("_")[1] for col in dft_data.columns if "Delta" in col] for phase in phases: assert phase in lattices, f"Lattice {phase} not implemented for EOS test." lattice = lattices[phase] + """' start_time = datetime.now() print(f"Start time for {phase} @ {model_name}: {start_time}") - + """ lattice_constants, energies = equation_of_state( - calc, lattice, symbol=element, volumes_per_atoms=volumes_per_atoms + calc, + lattice, + volumes_per_atoms, + symbol=element, ) - + """ end_time = datetime.now() duration = end_time - start_time hours, remainder = divmod(duration.seconds, 3600) minutes, seconds = divmod(remainder, 60) print(f"End time for {phase} @ {model_name}: {end_time}") - print(f"Duration for {phase} @ {model_name}: {hours} hours {minutes} minutes {seconds} seconds") + print( + f"Duration for {phase} @ {model_name}: " + f"{hours} hours {minutes} minutes {seconds} seconds" + ) print(duration) - - + """ results[f"{phase}_a"] = lattice_constants results[f"{phase}_E"] = energies @@ -115,4 +149,4 @@ def test_equation_of_state(mlip: tuple[str, Any]) -> None: df = pd.DataFrame(results) output_file = write_dir / f"{element}_eos_results.csv" write_dir.mkdir(parents=True, exist_ok=True) - df.to_csv(output_file, index=False) \ No newline at end of file + df.to_csv(output_file, index=False) From 08d923c5bb0d01c6df16f0505fcc5d9a7a93fa17 Mon Sep 17 00:00:00 2001 From: Petr Grigorev Date: Tue, 31 Mar 2026 16:39:13 +0200 Subject: [PATCH 5/7] added periodic table plots on click --- .../analyse_equation_of_state.py | 70 ++++++++++++++++++- .../app_equation_of_state.py | 28 +++++++- 2 files changed, 92 insertions(+), 6 deletions(-) diff --git a/ml_peg/analysis/bulk_crystal/equation_of_state/analyse_equation_of_state.py b/ml_peg/analysis/bulk_crystal/equation_of_state/analyse_equation_of_state.py index 8d0dd7573..fc9fa08d4 100644 --- a/ml_peg/analysis/bulk_crystal/equation_of_state/analyse_equation_of_state.py +++ b/ml_peg/analysis/bulk_crystal/equation_of_state/analyse_equation_of_state.py @@ -9,7 +9,7 @@ import pandas as pd import pytest -from ml_peg.analysis.utils.decorators import build_table +from ml_peg.analysis.utils.decorators import build_table, plot_periodic_table from ml_peg.analysis.utils.utils import load_metrics_config from ml_peg.app import APP_ROOT from ml_peg.calcs import CALCS_ROOT @@ -237,7 +237,7 @@ def eos_stats() -> dict[tuple[str, str], dict[str, float]]: ------- dict[tuple[str, str], dict[str, float]] Mapping of ``(model_name, element)`` to ``{"Delta", - "PhaseDiffEOS_MAE_meV", "CorrectStability_pct"}``. + "phase_diff_mae", "correct_stability"}``. """ OUT_PATH.mkdir(parents=True, exist_ok=True) results: dict[tuple[str, str], dict[str, float]] = {} @@ -307,6 +307,65 @@ def eos_stats() -> dict[tuple[str, str], dict[str, float]]: return results +def get_metric_per_element(model, eos_stats, metric_name): + """ + Get a dictionary of metric values for each element for a given model. + + Parameters + ---------- + model + The name of the model to extract metrics for. + eos_stats + The full EOS statistics dictionary containing all models and elements. + metric_name + The name of the metric to extract (e.g., "Δ", + "Phase energy", "Phase stability"). + + Returns + ------- + dict[str, float] + A dictionary mapping element symbols to their + corresponding metric values for the specified model. + """ + return { + el: eos_stats.get((model, el), {}).get(metric_name, np.nan) for el in ELEMENTS + } + + +@pytest.fixture +def periodic_tables( + eos_stats: dict[tuple[str, str], dict[str, float]], +) -> None: + """ + Write per-model periodic-table heatmaps for each EOS metric. + + Parameters + ---------- + eos_stats + Per-(model, element) metric values. + """ + file_suffixes = { + "Δ": "delta_periodic_table", + "Phase energy": "phase_energy_periodic_table", + "Phase stability": "phase_stability_periodic_table", + } + for model in MODELS: + for metric_name in ["Δ", "Phase energy", "Phase stability"]: + title = ( + f"{metric_name} (meV/atom)" + if metric_name != "Phase stability" + else f"{metric_name} (%)" + ) + colorbar_title = title + file_suffix = file_suffixes[metric_name] + values = get_metric_per_element(model, eos_stats, metric_name) + plot_periodic_table( + title=f"{title} - {model}", + colorbar_title=colorbar_title, + filename=str(OUT_PATH / model / f"{file_suffix}.json"), + )(lambda v=values: v)() + + @pytest.fixture def delta( eos_stats: dict[tuple[str, str], dict[str, float]], @@ -427,7 +486,10 @@ def metrics( } -def test_equation_of_state(metrics: dict[str, dict]) -> None: +def test_equation_of_state( + metrics: dict[str, dict], + periodic_tables: None, +) -> None: """ Run EOS benchmark analysis. @@ -435,5 +497,7 @@ def test_equation_of_state(metrics: dict[str, dict]) -> None: ---------- metrics All EOS benchmark metric values. + periodic_tables + Per-model periodic-table heatmaps (side-effect only). """ return diff --git a/ml_peg/app/bulk_crystal/equation_of_state/app_equation_of_state.py b/ml_peg/app/bulk_crystal/equation_of_state/app_equation_of_state.py index 2e2cb2d7c..6a82a3aee 100644 --- a/ml_peg/app/bulk_crystal/equation_of_state/app_equation_of_state.py +++ b/ml_peg/app/bulk_crystal/equation_of_state/app_equation_of_state.py @@ -7,6 +7,8 @@ from ml_peg.app import APP_ROOT from ml_peg.app.base_app import BaseApp +from ml_peg.app.utils.build_callbacks import plot_from_table_cell +from ml_peg.app.utils.load import read_plot from ml_peg.models.get_models import get_model_names from ml_peg.models.models import current_models @@ -22,9 +24,29 @@ class EquationOfStateApp(BaseApp): def register_callbacks(self) -> None: """Register callbacks to app.""" - # Build plots for models with data (read_density_plot_for_model - # returns None for models without data) - return + _metrics = [ + ("Δ", "delta_periodic_table"), + ("Phase energy", "phase_energy_periodic_table"), + ("Phase stability", "phase_stability_periodic_table"), + ] + cell_to_plot = {} + for model in MODELS: + plots = {} + for column_id, file_suffix in _metrics: + path = DATA_PATH / model / f"{file_suffix}.json" + if path.exists(): + plots[column_id] = read_plot( + filename=path, + id=f"{BENCHMARK_NAME}-{model}-{file_suffix}", + ) + if plots: + cell_to_plot[model] = plots + + plot_from_table_cell( + table_id=self.table_id, + plot_id=f"{BENCHMARK_NAME}-figure-placeholder", + cell_to_plot=cell_to_plot, + ) def get_app() -> EquationOfStateApp: From 7ecd59367b5a4b6f18367d2da63f8a2b2021b0d0 Mon Sep 17 00:00:00 2001 From: Petr Grigorev Date: Tue, 31 Mar 2026 17:38:33 +0200 Subject: [PATCH 6/7] added eos graphs on the element click --- .../app_equation_of_state.py | 155 ++++++++++++++++-- 1 file changed, 140 insertions(+), 15 deletions(-) diff --git a/ml_peg/app/bulk_crystal/equation_of_state/app_equation_of_state.py b/ml_peg/app/bulk_crystal/equation_of_state/app_equation_of_state.py index 6a82a3aee..1a5ea3f96 100644 --- a/ml_peg/app/bulk_crystal/equation_of_state/app_equation_of_state.py +++ b/ml_peg/app/bulk_crystal/equation_of_state/app_equation_of_state.py @@ -2,21 +2,101 @@ from __future__ import annotations -from dash import Dash +from pathlib import Path + +from dash import ALL, Dash, Input, Output, callback, callback_context +from dash.dcc import Graph +from dash.exceptions import PreventUpdate from dash.html import Div +import pandas as pd +from plotly.colors import qualitative +import plotly.graph_objects as go +from plotly.io import read_json from ml_peg.app import APP_ROOT from ml_peg.app.base_app import BaseApp from ml_peg.app.utils.build_callbacks import plot_from_table_cell -from ml_peg.app.utils.load import read_plot +from ml_peg.calcs import CALCS_ROOT from ml_peg.models.get_models import get_model_names from ml_peg.models.models import current_models -# Get all models MODELS = get_model_names(current_models) BENCHMARK_NAME = "Equation of State" -DOCS_URL = "https://ddmms.github.io/ml-peg/user_guide/benchmarks/bulk_crystal.html#equation-of-state" +DOCS_URL = ( + "https://ddmms.github.io/ml-peg/user_guide/" + "benchmarks/bulk_crystal.html#equation-of-state" +) DATA_PATH = APP_ROOT / "data" / "bulk_crystal" / "equation_of_state" +CALC_PATH = CALCS_ROOT / "bulk_crystal" / "equation_of_state" / "outputs" +INPUT_PATH = Path(__file__).parents[4] / "inputs" / "bulk_crystal" / "equation_of_state" +_PT_TYPE = "eos-periodic-table" +_EOS_CURVE_ID = f"{BENCHMARK_NAME}-eos-curve" +_METRICS = [ + ("\u0394", "delta_periodic_table"), + ("Phase energy", "phase_energy_periodic_table"), + ("Phase stability", "phase_stability_periodic_table"), +] + + +def _make_eos_figure(model: str, element: str) -> go.Figure | None: + """ + Create an equation of state figure for a given model and element. + + Parameters + ---------- + model : str + The model name. + element : str + The element name. + + Returns + ------- + go.Figure | None + The equation of state figure or None if the data is not available. + """ + model_csv = CALC_PATH / model / f"{element}_eos_results.csv" + dft_csv = INPUT_PATH / f"{element}_eos_DFT.csv" + if not model_csv.exists() or not dft_csv.exists(): + return None + model_data = pd.read_csv(model_csv) + dft_data = pd.read_csv(dft_csv, comment="#") + phases = [ + col.split("_")[1] + for col in dft_data.columns + if col.startswith("Delta_") and col.endswith("_E") + ] + colours = qualitative.D3 + fig = go.Figure() + for i, phase in enumerate(phases): + colour = colours[i % len(colours)] + dft_v = dft_data[f"V/atom_{phase}"].dropna() + dft_e = dft_data[f"Delta_{phase}_E"].loc[dft_v.index] + fig.add_trace( + go.Scatter( + x=dft_v, + y=dft_e, + mode="markers", + name=f"DFT {phase}", + marker={"symbol": "x", "color": colour, "size": 8}, + ) + ) + model_v = model_data["V/atom"] + model_delta_e = model_data[f"{phase}_E"] - model_data[f"{phases[0]}_E"].min() + fig.add_trace( + go.Scatter( + x=model_v, + y=model_delta_e, + mode="lines", + name=f"{model} {phase}", + line={"color": colour}, + ) + ) + fig.update_layout( + title=f"EOS - {element} ({model})", + xaxis_title="Volume per atom (\u00c5\u00b3)", + yaxis_title="Energy per atom (eV)", + ) + return fig class EquationOfStateApp(BaseApp): @@ -24,21 +104,21 @@ class EquationOfStateApp(BaseApp): def register_callbacks(self) -> None: """Register callbacks to app.""" - _metrics = [ - ("Δ", "delta_periodic_table"), - ("Phase energy", "phase_energy_periodic_table"), - ("Phase stability", "phase_stability_periodic_table"), - ] cell_to_plot = {} for model in MODELS: plots = {} - for column_id, file_suffix in _metrics: + for column_id, file_suffix in _METRICS: path = DATA_PATH / model / f"{file_suffix}.json" - if path.exists(): - plots[column_id] = read_plot( - filename=path, - id=f"{BENCHMARK_NAME}-{model}-{file_suffix}", - ) + if not path.exists(): + continue + plots[column_id] = Graph( + id={ + "type": _PT_TYPE, + "model": model, + "metric": file_suffix, + }, + figure=read_json(path), + ) if plots: cell_to_plot[model] = plots @@ -48,6 +128,50 @@ def register_callbacks(self) -> None: cell_to_plot=cell_to_plot, ) + @callback( + Output(_EOS_CURVE_ID, "children"), + Input( + {"type": _PT_TYPE, "model": ALL, "metric": ALL}, + "clickData", + ), + prevent_initial_call=True, + ) + def show_eos_curve(_): + """ + Show the equation of state curve for the clicked element and model. + + Parameters + ---------- + _ : Any + The click data from the periodic table graph. + The actual value is not used, but the callback + context is used to determine which cell was clicked. + + Returns + ------- + Div + The div containing the equation of state figure. + """ + ctx = callback_context + triggered_id = ctx.triggered_id + if not isinstance(triggered_id, dict): + raise PreventUpdate + click_data = ctx.triggered[0]["value"] + if not click_data: + raise PreventUpdate + points = click_data.get("points", []) + if not points: + raise PreventUpdate + text = points[0].get("text", "") + element = text.split("
")[0].strip() + if not element or len(element) > 3: + raise PreventUpdate + model = triggered_id["model"] + fig = _make_eos_figure(model, element) + if fig is None: + return Div(f"No data for {element} / {model}.") + return Div(Graph(figure=fig)) + def get_app() -> EquationOfStateApp: """ @@ -69,6 +193,7 @@ def get_app() -> EquationOfStateApp: table_path=DATA_PATH / "eos_metrics_table.json", extra_components=[ Div(id=f"{BENCHMARK_NAME}-figure-placeholder"), + Div(id=_EOS_CURVE_ID), ], ) From f941b79e9c84c60126f7beffbca1818ce8ca7b40 Mon Sep 17 00:00:00 2001 From: Petr Grigorev Date: Tue, 31 Mar 2026 17:44:36 +0200 Subject: [PATCH 7/7] fixed typos/improved description --- .../equation_of_state/app_equation_of_state.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/ml_peg/app/bulk_crystal/equation_of_state/app_equation_of_state.py b/ml_peg/app/bulk_crystal/equation_of_state/app_equation_of_state.py index 1a5ea3f96..db34eb592 100644 --- a/ml_peg/app/bulk_crystal/equation_of_state/app_equation_of_state.py +++ b/ml_peg/app/bulk_crystal/equation_of_state/app_equation_of_state.py @@ -21,7 +21,7 @@ from ml_peg.models.models import current_models MODELS = get_model_names(current_models) -BENCHMARK_NAME = "Equation of State" +BENCHMARK_NAME = "Equation of State (metals)" DOCS_URL = ( "https://ddmms.github.io/ml-peg/user_guide/" "benchmarks/bulk_crystal.html#equation-of-state" @@ -185,9 +185,8 @@ def get_app() -> EquationOfStateApp: return EquationOfStateApp( name=BENCHMARK_NAME, description=( - "Performance when calculating the equation of state for different " - "bulk crystal (W, Mo, Nb) structures " - "scomapred to PBE data from literature." + "Equation of state curves and phase stability for BCC metals " + "(W, Mo, Nb), benchmarked against PBE reference data." ), docs_url=DOCS_URL, table_path=DATA_PATH / "eos_metrics_table.json",