diff --git a/inftools/tistools/plot_ens.py b/inftools/tistools/plot_ens.py index 6e7aa20..4974b6b 100644 --- a/inftools/tistools/plot_ens.py +++ b/inftools/tistools/plot_ens.py @@ -1,15 +1,18 @@ -from typing import Annotated +from typing import Annotated as Atd import typer def plot_ens( - toml: Annotated[str, typer.Option("-toml")] = "restart.toml", - skip: Annotated[bool, typer.Option("-skip", help="skip initial load paths")] = False, - save: Annotated[ str, typer.Option("-save", help="save with scienceplots.") ] = "no", - pp: Annotated[ bool, typer.Option("-pp", help="partial paths version") ] = False, - cap: Annotated[ int, typer.Option("-cap", help="max paths plotted per ens") ] = 100, - time: Annotated[float, typer.Option("-time", help="divide dt by an amount") ] = 1, + toml: Atd[str, typer.Option("-toml")] = "restart.toml", + data: Atd[str, typer.Option("-data")] = "", + skip: Atd[bool, typer.Option("-skip", help="skip initial load paths")] = False, + save: Atd[ str, typer.Option("-save", help="save with scienceplots.") ] = "no", + pp: Atd[ bool, typer.Option("-pp", help="partial paths version") ] = False, + cap: Atd[ int, typer.Option("-cap", help="max paths plotted per ens") ] = 100, + time: Atd[float, typer.Option("-time", help="divide dt by an amount") ] = 1, + load: Atd[str, typer.Option("-load", + help = "the path directory, reads load_dir from toml if not given") ] = "", ): """Plot sampled ensemble paths with interfaces""" import os @@ -24,11 +27,20 @@ def plot_ens( plt.figure(figsize=(14, 10)) # Read toml info - with open("restart.toml", "rb") as toml_file: + with open(toml, "rb") as toml_file: toml = tomli.load(toml_file) intf = toml["simulation"]["interfaces"] - datafile = toml["output"]["data_file"] - load_dir = toml["simulation"]["load_dir"] + if not toml["output"].get("data_file", False) and not data: + exit("Supply a infretis_data.txt file with -data") + elif data: + print(f"Using {data}") + datafile = data + else: + datafile = toml["output"]["data_file"] + if not load: + load_dir = toml["simulation"]["load_dir"] + else: + load_dir = load plt.title("intfs: " + " ".join([str(i) for i in intf])) plt.axhline(intf[0], ls="--", color="k", alpha=0.5)