diff --git a/ml_peg/analysis/bulk_crystal/equation_of_state/analyse_equation_of_state.py b/ml_peg/analysis/bulk_crystal/equation_of_state/analyse_equation_of_state.py new file mode 100644 index 000000000..fc9fa08d4 --- /dev/null +++ b/ml_peg/analysis/bulk_crystal/equation_of_state/analyse_equation_of_state.py @@ -0,0 +1,503 @@ +"""Analyse equation of state benchmark.""" + +from __future__ import annotations + +from pathlib import Path + +from ase.eos import EquationOfState, birchmurnaghan +import numpy as np +import pandas as pd +import pytest + +from ml_peg.analysis.utils.decorators import build_table, plot_periodic_table +from ml_peg.analysis.utils.utils import load_metrics_config +from ml_peg.app import APP_ROOT +from ml_peg.calcs import CALCS_ROOT +from ml_peg.models.get_models import get_model_names +from ml_peg.models.models import current_models + +MODELS = get_model_names(current_models) +CALC_PATH = CALCS_ROOT / "bulk_crystal" / "equation_of_state" / "outputs" +OUT_PATH = APP_ROOT / "data" / "bulk_crystal" / "equation_of_state" + +DATA_PATH = Path(__file__).parent / "../../../../inputs/bulk_crystal/equation_of_state/" + +METRICS_CONFIG_PATH = Path(__file__).with_name("metrics.yml") +DEFAULT_THRESHOLDS, DEFAULT_TOOLTIPS, DEFAULT_WEIGHTS = load_metrics_config( + METRICS_CONFIG_PATH +) + +ELEMENTS = [f.name.split("_")[0] for f in DATA_PATH.glob("*DFT*")] + + +def _fit_bm_clean(volumes: np.ndarray, energies: np.ndarray) -> tuple | None: + """ + Fit a Birch-Murnaghan EOS, ignoring non-finite points. + + Parameters + ---------- + volumes + Per-atom volumes in ų. + energies + Per-atom energies in eV. + + Returns + ------- + tuple or None + ``eos.eos_parameters`` = ``(E0, B0, BP, V0)`` on success, or ``None`` + if fewer than 4 finite points exist or fitting fails. + """ + mask = np.isfinite(volumes) & np.isfinite(energies) + v = np.asarray(volumes)[mask] + e = np.asarray(energies)[mask] + if v.size < 4: + return None + try: + eos = EquationOfState(v, e, eos="birchmurnaghan") + eos.fit() + return eos.eos_parameters + except Exception: + return None + + +def calc_delta( + data_f: dict[str, float], + data_w: dict[str, float], + useasymm: bool, + vi: float, + vf: float, +) -> tuple[float, float, float]: + """ + Calculate the Delta metric (meV/atom) between two EOS curves. + + Expects B0 in ASE-native units (eV/A^3), as returned directly by + ``EquationOfState.eos_parameters``. No GPa conversion is applied. + + Adapted from ``calcDelta.py``, supplementary material from: + K. Lejaeghere, V. Van Speybroeck, G. Van Oost, and S. Cottenier: + "Error estimates for solid-state density-functional theory predictions: + an overview by means of the ground-state elemental crystals" + Crit. Rev. Solid State (2014). Open access: http://arxiv.org/abs/1204.2733 + + Parameters + ---------- + data_f + Reference EOS parameters: ``{"V0": float, "B0": float, "BP": float}``. + B0 in eV/A^3 (ASE units). + data_w + Model EOS parameters with the same keys. + useasymm + If ``True``, normalise Delta1 by the reference V0/B0 only; if + ``False`` use the average of reference and model values. + vi + Lower volume integration bound (A^3/atom). + vf + Upper volume integration bound (A^3/atom). + + Returns + ------- + tuple[float, float, float] + ``(delta, deltarel, delta1)`` where ``delta`` is in meV/atom. + """ + v0w = data_w["V0"] + b0w = data_w["B0"] + b1w = data_w["BP"] + + v0f = data_f["V0"] + b0f = data_f["B0"] + b1f = data_f["BP"] + + vref = 30.0 + bref = 100.0 # eV/A^3, matching ASE-native units + + a3f = 9.0 * v0f**3.0 * b0f / 16.0 * (b1f - 4.0) + a2f = 9.0 * v0f ** (7.0 / 3.0) * b0f / 16.0 * (14.0 - 3.0 * b1f) + a1f = 9.0 * v0f ** (5.0 / 3.0) * b0f / 16.0 * (3.0 * b1f - 16.0) + a0f = 9.0 * v0f * b0f / 16.0 * (6.0 - b1f) + + a3w = 9.0 * v0w**3.0 * b0w / 16.0 * (b1w - 4.0) + a2w = 9.0 * v0w ** (7.0 / 3.0) * b0w / 16.0 * (14.0 - 3.0 * b1w) + a1w = 9.0 * v0w ** (5.0 / 3.0) * b0w / 16.0 * (3.0 * b1w - 16.0) + a0w = 9.0 * v0w * b0w / 16.0 * (6.0 - b1w) + + x = [0.0] * 7 + x[0] = (a0f - a0w) ** 2 + x[1] = 6.0 * (a1f - a1w) * (a0f - a0w) + x[2] = -3.0 * (2.0 * (a2f - a2w) * (a0f - a0w) + (a1f - a1w) ** 2.0) + x[3] = -2.0 * (a3f - a3w) * (a0f - a0w) - 2.0 * (a2f - a2w) * (a1f - a1w) + x[4] = -3.0 / 5.0 * (2.0 * (a3f - a3w) * (a1f - a1w) + (a2f - a2w) ** 2.0) + x[5] = -6.0 / 7.0 * (a3f - a3w) * (a2f - a2w) + x[6] = -1.0 / 3.0 * (a3f - a3w) ** 2.0 + + y = [0.0] * 7 + y[0] = (a0f + a0w) ** 2 / 4.0 + y[1] = 3.0 * (a1f + a1w) * (a0f + a0w) / 2.0 + y[2] = -3.0 * (2.0 * (a2f + a2w) * (a0f + a0w) + (a1f + a1w) ** 2.0) / 4.0 + y[3] = -(a3f + a3w) * (a0f + a0w) / 2.0 - (a2f + a2w) * (a1f + a1w) / 2.0 + y[4] = -3.0 / 20.0 * (2.0 * (a3f + a3w) * (a1f + a1w) + (a2f + a2w) ** 2.0) + y[5] = -3.0 / 14.0 * (a3f + a3w) * (a2f + a2w) + y[6] = -1.0 / 12.0 * (a3f + a3w) ** 2.0 + + fi = 0.0 + ff = 0.0 + gi = 0.0 + gf = 0.0 + for n in range(7): + fi += x[n] * vi ** (-(2.0 * n - 3.0) / 3.0) + ff += x[n] * vf ** (-(2.0 * n - 3.0) / 3.0) + gi += y[n] * vi ** (-(2.0 * n - 3.0) / 3.0) + gf += y[n] * vf ** (-(2.0 * n - 3.0) / 3.0) + + delta = 1000.0 * np.sqrt((ff - fi) / (vf - vi)) + deltarel = 100.0 * np.sqrt((ff - fi) / (gf - gi)) + if useasymm: + delta1 = delta / v0w / b0w * vref * bref + else: + delta1 = delta / (v0w + v0f) / (b0w + b0f) * 4.0 * vref * bref + + return delta, deltarel, delta1 + + +def _phase_metrics_from_eos_mev( + dft_data: pd.DataFrame, model_data: pd.DataFrame +) -> tuple[float, float]: + """ + Compute phase-stability metrics relative to the reference phase. + + Parameters + ---------- + dft_data + DFT reference DataFrame with columns ``V/atom_{phase}`` and + ``Delta_{phase}_E`` for each phase. + model_data + Model results DataFrame with columns ``V/atom`` and ``{phase}_E`` + for each phase. + + Returns + ------- + tuple[float, float] + ``(PhaseDiffEOS_MAE_meV, CorrectStability_pct)``. Returns + ``(nan, nan)`` if any BM fit fails. + """ + phases = [ + col.split("_")[1] + for col in dft_data.columns + if col.startswith("Delta_") and col.endswith("_E") + ] + + dft_fits = [ + _fit_bm_clean( + dft_data[f"V/atom_{phase}"].values, + dft_data[f"Delta_{phase}_E"].values, + ) + for phase in phases + ] + model_fits = [ + _fit_bm_clean(model_data["V/atom"].values, model_data[f"{phase}_E"].values) + for phase in phases + ] + + if any(fit is None for fit in dft_fits + model_fits): + return np.nan, np.nan + + ref_volumes = dft_data[f"V/atom_{phases[0]}"].values + ref_volumes = ref_volumes[np.isfinite(ref_volumes)] + v_grid = np.linspace(ref_volumes.min(), ref_volumes.max(), 80) + + dft_deltas = np.vstack( + [ + birchmurnaghan(v_grid, *dft_fit) - birchmurnaghan(v_grid, *dft_fits[0]) + for dft_fit in dft_fits[1:] + ] + ) + model_deltas = np.vstack( + [ + birchmurnaghan(v_grid, *model_fit) - birchmurnaghan(v_grid, *model_fits[0]) + for model_fit in model_fits[1:] + ] + ) + + mae_ev = float(np.mean(np.abs(dft_deltas - model_deltas))) + correct_stability_pct = 100.0 * float(np.mean(np.all(model_deltas > 0, axis=0))) + + return 1000.0 * mae_ev, correct_stability_pct + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def eos_stats() -> dict[tuple[str, str], dict[str, float]]: + """ + Compute all three metrics for every model-element pair. + + Returns + ------- + dict[tuple[str, str], dict[str, float]] + Mapping of ``(model_name, element)`` to ``{"Delta", + "phase_diff_mae", "correct_stability"}``. + """ + OUT_PATH.mkdir(parents=True, exist_ok=True) + results: dict[tuple[str, str], dict[str, float]] = {} + + for model_name in MODELS: + model_dir = CALC_PATH / model_name + if not model_dir.exists(): + continue + + for element in ELEMENTS: + model_csv = model_dir / f"{element}_eos_results.csv" + dft_csv = DATA_PATH / f"{element}_eos_DFT.csv" + if not model_csv.exists() or not dft_csv.exists(): + continue + + model_data = pd.read_csv(model_csv) + dft_data = pd.read_csv(dft_csv, comment="#") + + phases = [ + col.split("_")[1] + for col in dft_data.columns + if col.startswith("Delta_") and col.endswith("_E") + ] + # BM fit for reference phase (Delta metric) + ref_phase = phases[0] # Assuming the first phase is the reference + ref_params = _fit_bm_clean( + model_data["V/atom"].values, model_data[f"{ref_phase}_E"].values + ) + dft_ref_params = _fit_bm_clean( + dft_data[f"V/atom_{ref_phase}"].values, + dft_data[f"Delta_{ref_phase}_E"].values, + ) + + if ref_params is not None and dft_ref_params is not None: + # eos_parameters: (E0, B0, BP, V0) + model_bm = { + "V0": ref_params[-1], + "B0": ref_params[-3], + "BP": ref_params[-2], + } + dft_bm = { + "V0": dft_ref_params[-1], + "B0": dft_ref_params[-3], + "BP": dft_ref_params[-2], + } + volumes = model_data["V/atom"] + delta, _, _ = calc_delta( + dft_bm, + model_bm, + useasymm=False, + vi=float(volumes.iloc[0]), + vf=float(volumes.iloc[-1]), + ) + else: + delta = np.nan + + phase_diff_mae, correct_stability = _phase_metrics_from_eos_mev( + dft_data, model_data + ) + + results[(model_name, element)] = { + "Δ": delta, + "Phase energy": phase_diff_mae, + "Phase stability": correct_stability, + } + + return results + + +def get_metric_per_element(model, eos_stats, metric_name): + """ + Get a dictionary of metric values for each element for a given model. + + Parameters + ---------- + model + The name of the model to extract metrics for. + eos_stats + The full EOS statistics dictionary containing all models and elements. + metric_name + The name of the metric to extract (e.g., "Δ", + "Phase energy", "Phase stability"). + + Returns + ------- + dict[str, float] + A dictionary mapping element symbols to their + corresponding metric values for the specified model. + """ + return { + el: eos_stats.get((model, el), {}).get(metric_name, np.nan) for el in ELEMENTS + } + + +@pytest.fixture +def periodic_tables( + eos_stats: dict[tuple[str, str], dict[str, float]], +) -> None: + """ + Write per-model periodic-table heatmaps for each EOS metric. + + Parameters + ---------- + eos_stats + Per-(model, element) metric values. + """ + file_suffixes = { + "Δ": "delta_periodic_table", + "Phase energy": "phase_energy_periodic_table", + "Phase stability": "phase_stability_periodic_table", + } + for model in MODELS: + for metric_name in ["Δ", "Phase energy", "Phase stability"]: + title = ( + f"{metric_name} (meV/atom)" + if metric_name != "Phase stability" + else f"{metric_name} (%)" + ) + colorbar_title = title + file_suffix = file_suffixes[metric_name] + values = get_metric_per_element(model, eos_stats, metric_name) + plot_periodic_table( + title=f"{title} - {model}", + colorbar_title=colorbar_title, + filename=str(OUT_PATH / model / f"{file_suffix}.json"), + )(lambda v=values: v)() + + +@pytest.fixture +def delta( + eos_stats: dict[tuple[str, str], dict[str, float]], +) -> dict[str, float]: + """ + Mean Delta (meV/atom) across elements for each model. + + Parameters + ---------- + eos_stats + Per-(model, element) metric values. + + Returns + ------- + dict[str, float] + Mean Delta per model. + """ + results: dict[str, float] = {} + for model_name in MODELS: + values = [ + eos_stats[(model_name, el)]["Δ"] + for el in ELEMENTS + if (model_name, el) in eos_stats + ] + results[model_name] = float(np.nanmean(values)) if values else None + return results + + +@pytest.fixture +def phase_diff_eos_mae( + eos_stats: dict[tuple[str, str], dict[str, float]], +) -> dict[str, float]: + """ + Mean PhaseDiffEOS MAE (meV/atom) across elements for each model. + + Parameters + ---------- + eos_stats + Per-(model, element) metric values. + + Returns + ------- + dict[str, float] + Mean PhaseDiffEOS_MAE_meV per model. + """ + results: dict[str, float] = {} + for model_name in MODELS: + values = [ + eos_stats[(model_name, el)]["Phase energy"] + for el in ELEMENTS + if (model_name, el) in eos_stats + ] + results[model_name] = float(np.nanmean(values)) if values else None + return results + + +@pytest.fixture +def correct_stability( + eos_stats: dict[tuple[str, str], dict[str, float]], +) -> dict[str, float]: + """ + Mean CorrectStability (%) across elements for each model. + + Parameters + ---------- + eos_stats + Per-(model, element) metric values. + + Returns + ------- + dict[str, float] + Mean Phase stability per model. + """ + results: dict[str, float] = {} + for model_name in MODELS: + values = [ + eos_stats[(model_name, el)]["Phase stability"] + for el in ELEMENTS + if (model_name, el) in eos_stats + ] + results[model_name] = float(np.nanmean(values)) if values else None + return results + + +@pytest.fixture +@build_table( + filename=OUT_PATH / "eos_metrics_table.json", + metric_tooltips=DEFAULT_TOOLTIPS, + thresholds=DEFAULT_THRESHOLDS, + weights=DEFAULT_WEIGHTS, +) +def metrics( + delta: dict[str, float], + phase_diff_eos_mae: dict[str, float], + correct_stability: dict[str, float], +) -> dict[str, dict]: + """ + All EOS benchmark metrics. + + Parameters + ---------- + delta + Mean Delta per model (meV/atom). + phase_diff_eos_mae + Mean PhaseDiffEOS MAE per model (meV/atom). + correct_stability + Mean Phase stability per model (%). + + Returns + ------- + dict[str, dict] + Mapping of metric name to per-model value dicts. + """ + return { + "Δ": delta, + "Phase energy": phase_diff_eos_mae, + "Phase stability": correct_stability, + } + + +def test_equation_of_state( + metrics: dict[str, dict], + periodic_tables: None, +) -> None: + """ + Run EOS benchmark analysis. + + Parameters + ---------- + metrics + All EOS benchmark metric values. + periodic_tables + Per-model periodic-table heatmaps (side-effect only). + """ + return diff --git a/ml_peg/analysis/bulk_crystal/equation_of_state/metrics.yml b/ml_peg/analysis/bulk_crystal/equation_of_state/metrics.yml new file mode 100644 index 000000000..5f22e6139 --- /dev/null +++ b/ml_peg/analysis/bulk_crystal/equation_of_state/metrics.yml @@ -0,0 +1,19 @@ +metrics: + Δ: + good: 0.0 + bad: 200.0 + unit: meV/atom + tooltip: RMS energy difference per atom between the two EOS curves averaged over the sampled volume range. + level_of_theory: PBE + Phase energy: + good: 0.0 + bad: 200.0 + unit: meV/atom + tooltip: Mean absolute error of the phase energy differences relative to the reference phase between model and DFT. + level_of_theory: PBE + Phase stability: + good: 100.0 + bad: 0.0 + unit: '%' + tooltip: Percentage of correctly predicted stable phases. + level_of_theory: PBE diff --git a/ml_peg/app/bulk_crystal/equation_of_state/app_equation_of_state.py b/ml_peg/app/bulk_crystal/equation_of_state/app_equation_of_state.py new file mode 100644 index 000000000..db34eb592 --- /dev/null +++ b/ml_peg/app/bulk_crystal/equation_of_state/app_equation_of_state.py @@ -0,0 +1,205 @@ +"""Run equation of state benchmark app.""" + +from __future__ import annotations + +from pathlib import Path + +from dash import ALL, Dash, Input, Output, callback, callback_context +from dash.dcc import Graph +from dash.exceptions import PreventUpdate +from dash.html import Div +import pandas as pd +from plotly.colors import qualitative +import plotly.graph_objects as go +from plotly.io import read_json + +from ml_peg.app import APP_ROOT +from ml_peg.app.base_app import BaseApp +from ml_peg.app.utils.build_callbacks import plot_from_table_cell +from ml_peg.calcs import CALCS_ROOT +from ml_peg.models.get_models import get_model_names +from ml_peg.models.models import current_models + +MODELS = get_model_names(current_models) +BENCHMARK_NAME = "Equation of State (metals)" +DOCS_URL = ( + "https://ddmms.github.io/ml-peg/user_guide/" + "benchmarks/bulk_crystal.html#equation-of-state" +) +DATA_PATH = APP_ROOT / "data" / "bulk_crystal" / "equation_of_state" +CALC_PATH = CALCS_ROOT / "bulk_crystal" / "equation_of_state" / "outputs" +INPUT_PATH = Path(__file__).parents[4] / "inputs" / "bulk_crystal" / "equation_of_state" +_PT_TYPE = "eos-periodic-table" +_EOS_CURVE_ID = f"{BENCHMARK_NAME}-eos-curve" +_METRICS = [ + ("\u0394", "delta_periodic_table"), + ("Phase energy", "phase_energy_periodic_table"), + ("Phase stability", "phase_stability_periodic_table"), +] + + +def _make_eos_figure(model: str, element: str) -> go.Figure | None: + """ + Create an equation of state figure for a given model and element. + + Parameters + ---------- + model : str + The model name. + element : str + The element name. + + Returns + ------- + go.Figure | None + The equation of state figure or None if the data is not available. + """ + model_csv = CALC_PATH / model / f"{element}_eos_results.csv" + dft_csv = INPUT_PATH / f"{element}_eos_DFT.csv" + if not model_csv.exists() or not dft_csv.exists(): + return None + model_data = pd.read_csv(model_csv) + dft_data = pd.read_csv(dft_csv, comment="#") + phases = [ + col.split("_")[1] + for col in dft_data.columns + if col.startswith("Delta_") and col.endswith("_E") + ] + colours = qualitative.D3 + fig = go.Figure() + for i, phase in enumerate(phases): + colour = colours[i % len(colours)] + dft_v = dft_data[f"V/atom_{phase}"].dropna() + dft_e = dft_data[f"Delta_{phase}_E"].loc[dft_v.index] + fig.add_trace( + go.Scatter( + x=dft_v, + y=dft_e, + mode="markers", + name=f"DFT {phase}", + marker={"symbol": "x", "color": colour, "size": 8}, + ) + ) + model_v = model_data["V/atom"] + model_delta_e = model_data[f"{phase}_E"] - model_data[f"{phases[0]}_E"].min() + fig.add_trace( + go.Scatter( + x=model_v, + y=model_delta_e, + mode="lines", + name=f"{model} {phase}", + line={"color": colour}, + ) + ) + fig.update_layout( + title=f"EOS - {element} ({model})", + xaxis_title="Volume per atom (\u00c5\u00b3)", + yaxis_title="Energy per atom (eV)", + ) + return fig + + +class EquationOfStateApp(BaseApp): + """Equation of State benchmark app layout and callbacks.""" + + def register_callbacks(self) -> None: + """Register callbacks to app.""" + cell_to_plot = {} + for model in MODELS: + plots = {} + for column_id, file_suffix in _METRICS: + path = DATA_PATH / model / f"{file_suffix}.json" + if not path.exists(): + continue + plots[column_id] = Graph( + id={ + "type": _PT_TYPE, + "model": model, + "metric": file_suffix, + }, + figure=read_json(path), + ) + if plots: + cell_to_plot[model] = plots + + plot_from_table_cell( + table_id=self.table_id, + plot_id=f"{BENCHMARK_NAME}-figure-placeholder", + cell_to_plot=cell_to_plot, + ) + + @callback( + Output(_EOS_CURVE_ID, "children"), + Input( + {"type": _PT_TYPE, "model": ALL, "metric": ALL}, + "clickData", + ), + prevent_initial_call=True, + ) + def show_eos_curve(_): + """ + Show the equation of state curve for the clicked element and model. + + Parameters + ---------- + _ : Any + The click data from the periodic table graph. + The actual value is not used, but the callback + context is used to determine which cell was clicked. + + Returns + ------- + Div + The div containing the equation of state figure. + """ + ctx = callback_context + triggered_id = ctx.triggered_id + if not isinstance(triggered_id, dict): + raise PreventUpdate + click_data = ctx.triggered[0]["value"] + if not click_data: + raise PreventUpdate + points = click_data.get("points", []) + if not points: + raise PreventUpdate + text = points[0].get("text", "") + element = text.split("
")[0].strip() + if not element or len(element) > 3: + raise PreventUpdate + model = triggered_id["model"] + fig = _make_eos_figure(model, element) + if fig is None: + return Div(f"No data for {element} / {model}.") + return Div(Graph(figure=fig)) + + +def get_app() -> EquationOfStateApp: + """ + Get equation of state benchmark app layout and callback registration. + + Returns + ------- + EquationOfStateApp + Benchmark layout and callback registration. + """ + return EquationOfStateApp( + name=BENCHMARK_NAME, + description=( + "Equation of state curves and phase stability for BCC metals " + "(W, Mo, Nb), benchmarked against PBE reference data." + ), + docs_url=DOCS_URL, + table_path=DATA_PATH / "eos_metrics_table.json", + extra_components=[ + Div(id=f"{BENCHMARK_NAME}-figure-placeholder"), + Div(id=_EOS_CURVE_ID), + ], + ) + + +if __name__ == "__main__": + full_app = Dash(__name__, assets_folder=DATA_PATH.parent.parent) + equation_of_state_app = get_app() + full_app.layout = equation_of_state_app.layout + equation_of_state_app.register_callbacks() + full_app.run(port=8054, debug=True) diff --git a/ml_peg/calcs/bulk_crystal/equation_of_state/calc_equation_of_state.py b/ml_peg/calcs/bulk_crystal/equation_of_state/calc_equation_of_state.py new file mode 100644 index 000000000..ccf7b5c11 --- /dev/null +++ b/ml_peg/calcs/bulk_crystal/equation_of_state/calc_equation_of_state.py @@ -0,0 +1,152 @@ +"""Run calculations for EOS tests.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from ase.lattice.cubic import BodyCenteredCubic, FaceCenteredCubic, SimpleCubicFactory +import numpy as np +import pandas as pd +import pytest + +from ml_peg.models.get_models import load_models +from ml_peg.models.models import current_models + +MODELS = load_models(current_models) + +DATA_PATH = Path(__file__).parent / "../../../../inputs/bulk_crystal/equation_of_state/" +OUT_PATH = Path(__file__).parent / "outputs" + + +class A15Factory(SimpleCubicFactory): + """A factory for creating A15 lattices.""" + + xtal_name = "A15" + bravais_basis = [ + [0, 0, 0], + [0.5, 0.5, 0.5], + [0.5, 0.25, 0.0], + [0.5, 0.75, 0.0], + [0.0, 0.5, 0.25], + [0.0, 0.5, 0.75], + [0.25, 0.0, 0.5], + [0.75, 0.0, 0.5], + ] + + +A15 = A15Factory() + +lattices = {"BCC": BodyCenteredCubic, "FCC": FaceCenteredCubic, "A15": A15} + + +def equation_of_state( + calc, + lattice, + volumes_per_atoms, + symbol="W", + size=(2, 2, 2), +): + """ + Compute the equation of state for a given element and lattice. + + Parameters + ---------- + calc + ASE calculator to use for energy calculations. + lattice + ASE lattice class to use for structure generation. + volumes_per_atoms + Array of volumes per atom to sample for the EOS curve. + symbol + Chemical symbol of the element to use for structure generation. + size + Size of the supercell to generate for each volume per atom. + + Returns + ------- + tuple[np.ndarray, np.ndarray] + Lattice constants (A) and energies (eV/atom) arrays. + """ + # dummy call to have calc_num_atoms available + lattice(symbol="W", latticeconstant=3.16) + lattice_constants = (volumes_per_atoms * lattice.calc_num_atoms()) ** (1 / 3) + + structures = [ + lattice(latticeconstant=lc, size=size, symbol=symbol) + for lc in lattice_constants + ] + for structure in structures: + structure.calc = calc + + energies = [ + structure.get_potential_energy() / len(structure) for structure in structures + ] + + return np.array(lattice_constants), np.array(energies) + + +@pytest.mark.parametrize("mlip", MODELS.items()) +def test_equation_of_state(mlip: tuple[str, Any]) -> None: + """ + Test equation of state calculation for three BCC metals. + + Parameters + ---------- + mlip + Tuple of (model_name, model) as provided by pytest parametrize. + """ + model_name, model = mlip + calc = model.get_calculator() + + fns = list(DATA_PATH.glob("*DFT*")) + + for fn in fns: + element = fn.name.split("_")[0] + print(f"Starting EOS calculations for {element} with model {model_name}") + + dft_data = pd.read_csv(fn, comment="#") + + volumes_per_atoms = np.linspace( + np.round(dft_data[dft_data.columns[0]].min() * 0.95), + np.round(dft_data[dft_data.columns[0]].max() * 1.05), + 50, + endpoint=False, + ) + results = {"V/atom": volumes_per_atoms} + + phases = [col.split("_")[1] for col in dft_data.columns if "Delta" in col] + + for phase in phases: + assert phase in lattices, f"Lattice {phase} not implemented for EOS test." + lattice = lattices[phase] + """' + start_time = datetime.now() + print(f"Start time for {phase} @ {model_name}: {start_time}") + """ + lattice_constants, energies = equation_of_state( + calc, + lattice, + volumes_per_atoms, + symbol=element, + ) + """ + end_time = datetime.now() + duration = end_time - start_time + hours, remainder = divmod(duration.seconds, 3600) + minutes, seconds = divmod(remainder, 60) + print(f"End time for {phase} @ {model_name}: {end_time}") + print( + f"Duration for {phase} @ {model_name}: " + f"{hours} hours {minutes} minutes {seconds} seconds" + ) + print(duration) + """ + results[f"{phase}_a"] = lattice_constants + results[f"{phase}_E"] = energies + + write_dir = OUT_PATH / model_name + df = pd.DataFrame(results) + output_file = write_dir / f"{element}_eos_results.csv" + write_dir.mkdir(parents=True, exist_ok=True) + df.to_csv(output_file, index=False)