diff --git a/forestplot/dataframe_utils.py b/forestplot/dataframe_utils.py index 42907b6..8ff0118 100644 --- a/forestplot/dataframe_utils.py +++ b/forestplot/dataframe_utils.py @@ -1,9 +1,13 @@ """Holds functions to check prepare dataframe for plotting.""" +import os +from pathlib import Path from typing import Any, Optional, Union import numpy as np import pandas as pd +offline = os.getenv("FORESTPLOT_OFFLINE") + def insert_groups( dataframe: pd.core.frame.DataFrame, groupvar: str, varlabel: str @@ -120,19 +124,26 @@ def insert_empty_row(dataframe: pd.core.frame.DataFrame) -> pd.core.frame.DataFr return dataframe -def load_data(name: str, **param_dict: Optional[Any]) -> pd.core.frame.DataFrame: +def load_data( + name: str, + data_path: Union[Path, str] = Path("./examples/data/"), + **param_dict: Optional[Any], +) -> pd.core.frame.DataFrame: """ Load example dataset for quickstart. Example data available now: - mortality - The source of these data will be from: https://github.com/LSYS/forestplot/tree/main/examples/data. + The source of the data will be in forestplot/examples/data if files exist there + else from https://github.com/LSYS/forestplot/tree/main/examples/data. Parameters ---------- name (str) Name of the example data set. + data_path (Path) [Optional] + Directory containing local copies of csv data Returns ------- @@ -141,10 +152,16 @@ def load_data(name: str, **param_dict: Optional[Any]) -> pd.core.frame.DataFrame available_data = ["mortality", "sleep", "sleep-untruncated"] name = name.lower().strip() if name in available_data: - url = ( - f"https://raw.githubusercontent.com/lsys/forestplot/main/examples/data/{name}.csv" - ) - df = pd.read_csv(url, **param_dict) + data_path = Path(data_path) / f"{name}.csv" + if data_path.is_file(): + df = pd.read_csv(data_path, **param_dict) + elif offline: + raise AssertionError( + f"{data_path} not found. Working offline (FORESTPLOT_OFFLINE={offline})." + ) + else: + url = f"https://github.com/LSYS/forestplot/tree/main/examples/data/{name}.csv" + df = pd.read_csv(url, **param_dict) if name == "sleep": df["n"] = df["n"].astype("str") return df diff --git a/forestplot/graph_utils.py b/forestplot/graph_utils.py index a81ba09..749ed1d 100644 --- a/forestplot/graph_utils.py +++ b/forestplot/graph_utils.py @@ -9,6 +9,11 @@ warnings.filterwarnings("ignore") +def _get_pad(ax: Axes, **kwargs: Optional[Any]) -> float: + extrapad = kwargs.get("extrapad", 0.05) + return ax.get_xlim()[1] + extrapad * (ax.get_xlim()[1] - ax.get_xlim()[0]) + + def draw_ci( dataframe: pd.core.frame.DataFrame, estimate: str, @@ -233,8 +238,7 @@ def draw_pval_right( if pd.isna(yticklabel2): yticklabel2 = "" - extrapad = 0.05 - pad = ax.get_xlim()[1] * (1 + extrapad) + pad = _get_pad(ax, **kwargs) t = ax.text( x=pad, y=yticklabel1, @@ -330,8 +334,7 @@ def draw_yticklabel2( yticklabel1 = row["yticklabel"] yticklabel2 = row["yticklabel2"] - extrapad = 0.05 - pad = ax.get_xlim()[1] * (1 + extrapad) + pad = _get_pad(ax, **kwargs) if (ix == top_row_ix) and ( annoteheaders is not None or right_annoteheaders is not None ): @@ -706,8 +709,7 @@ def draw_tablelines( [x0, x1], [nrows - 1.45, nrows - 1.45], color="0.5", linewidth=lower_lw, clip_on=False ) if (right_annoteheaders is not None) or (pval is not None): - extrapad = kwargs.get("extrapad", 0.05) - x0 = ax.get_xlim()[1] * (1 + extrapad) + x0 = _get_pad(ax, **kwargs) plt.plot( [x0, righttext_width], [nrows - 0.4, nrows - 0.4], diff --git a/tests/test_plot.py b/tests/test_plot.py index 39bbc2d..94304d3 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -1,12 +1,17 @@ #!/usr/bin/env python # coding: utf-8 +from pathlib import Path + import pandas as pd from matplotlib.pyplot import Axes from forestplot import forestplot dataname = "sleep" -data = f"https://raw.githubusercontent.com/lsys/pyforestplot/main/examples/data/{dataname}.csv" +data = Path(f"./examples/data/{dataname}.csv") +if not data.is_file(): + data = f"https://raw.githubusercontent.com/lsys/pyforestplot/main/examples/data/{dataname}.csv" + df = pd.read_csv(data).assign(n=lambda df: df["n"].map(str)) # fmt: off