Function To Cast InferenceData Into tidy_draws Format #36

Open

wants to merge 88 commits into base: main

Changes from 36 commits

Commits (88)
3865001
initial commit for this PR; begin skeleton experimentation file
AFg6K7h4fhy2 Oct 28, 2024
763355e
some unfinished experimentation code; priority status change high to …
AFg6K7h4fhy2 Oct 28, 2024
44e7fe2
add first semi-failed attempt at converting entire idata object to ti…
AFg6K7h4fhy2 Oct 30, 2024
31c7b72
add attempt at option 2
AFg6K7h4fhy2 Oct 30, 2024
9a87902
slightly modify spread draws example
AFg6K7h4fhy2 Nov 4, 2024
c632ae8
more minor changes to tidy draws notebook
AFg6K7h4fhy2 Nov 4, 2024
123ad51
light edits during DHM convo
AFg6K7h4fhy2 Nov 7, 2024
a3c2d17
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 Nov 25, 2024
f44a6ee
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 Nov 25, 2024
cb883e3
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 Nov 25, 2024
df922d4
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 Nov 26, 2024
21968be
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 Dec 5, 2024
7dcd7d3
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 Dec 9, 2024
718ba85
a DB conversion attempt
AFg6K7h4fhy2 Dec 12, 2024
7394d4d
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 Dec 16, 2024
4a77d50
begin references file; create external program folder
AFg6K7h4fhy2 Dec 16, 2024
4c18634
add csv ignoring when converting to csv
AFg6K7h4fhy2 Dec 16, 2024
9be91b4
further edits; problem with dates
AFg6K7h4fhy2 Dec 16, 2024
0884c3d
partial attempt edits
AFg6K7h4fhy2 Dec 16, 2024
0dd9616
minor reconsideration of unpivot pathways
AFg6K7h4fhy2 Dec 16, 2024
2a0f2b9
minor update, tidydraws replicator
AFg6K7h4fhy2 Dec 18, 2024
c746e3b
minor cleaning of existing notebook code
AFg6K7h4fhy2 Dec 18, 2024
e4e8e0f
some of the reimplementation attempt
AFg6K7h4fhy2 Jan 16, 2025
c5ed832
testing of conversion between date intervals and dataframe
AFg6K7h4fhy2 Jan 17, 2025
e9e2aab
small further edit
AFg6K7h4fhy2 Jan 17, 2025
9a27187
add pyrenew inference file; some code mods
AFg6K7h4fhy2 Jan 17, 2025
cc20975
remove extraneous notebooks
AFg6K7h4fhy2 Jan 21, 2025
4a8a0ae
remove more idata to tidy code
AFg6K7h4fhy2 Jan 21, 2025
0c4301a
create tidy draws dataframe dict
AFg6K7h4fhy2 Jan 21, 2025
814f68c
error comment
AFg6K7h4fhy2 Jan 21, 2025
beba8f3
remove earlier code version
AFg6K7h4fhy2 Jan 22, 2025
b12ffc2
begin port of notebook into codebase
AFg6K7h4fhy2 Jan 22, 2025
d85f49c
update docstring
AFg6K7h4fhy2 Jan 22, 2025
85c0dbf
add some of the notebook
AFg6K7h4fhy2 Jan 22, 2025
786d2e4
some minor notebook edits
AFg6K7h4fhy2 Jan 22, 2025
4c4e6dd
further refine notebook
AFg6K7h4fhy2 Jan 22, 2025
e949796
update func name for clarity
AFg6K7h4fhy2 Jan 28, 2025
ab0beb6
remove historical scripts
AFg6K7h4fhy2 Jan 28, 2025
27e7ada
tidybayes API
AFg6K7h4fhy2 Jan 28, 2025
b06f6e6
unfinished edits addressing several comments
AFg6K7h4fhy2 Jan 28, 2025
42acd05
docstring change
AFg6K7h4fhy2 Jan 28, 2025
2ed4c29
more corrective edits
AFg6K7h4fhy2 Jan 28, 2025
0ad7c24
another corrective edit
AFg6K7h4fhy2 Jan 28, 2025
9452acd
check for invalid groups
AFg6K7h4fhy2 Jan 29, 2025
325aac4
some extranous code removal
AFg6K7h4fhy2 Jan 30, 2025
462d57b
add partial, in-need-of-edits tests
AFg6K7h4fhy2 Jan 30, 2025
dae049f
comment test to prevent fail; add nesting of groups
AFg6K7h4fhy2 Jan 30, 2025
5e4cd63
comment test to prevent fail; add nesting of groups
AFg6K7h4fhy2 Jan 30, 2025
5c28831
small update to showcase script
AFg6K7h4fhy2 Jan 30, 2025
d246a44
uncomment failing tests; run pre-commit
AFg6K7h4fhy2 Feb 3, 2025
496c65c
attempt fix pre-commit issues
AFg6K7h4fhy2 Feb 3, 2025
2c34af7
pre-commit edits
AFg6K7h4fhy2 Feb 3, 2025
dba6e7a
more prudent usage of ruff; pre-commit fixes
AFg6K7h4fhy2 Feb 4, 2025
e1cdbb9
revert version edits from black error
AFg6K7h4fhy2 Feb 4, 2025
9884aa9
use selectors; finally, nice to figure that out
AFg6K7h4fhy2 Feb 4, 2025
90f5484
fix some conflicts with daily to epiweekly
AFg6K7h4fhy2 Feb 4, 2025
683a7bf
Merge branch 'main' into 18-function-to-cast-inferencedata-into-tidy_…
AFg6K7h4fhy2 Feb 4, 2025
3547ab6
fix pre-commit errors
AFg6K7h4fhy2 Feb 4, 2025
2999ec3
add pivot to make life easier; not sure if to aggregate by first
AFg6K7h4fhy2 Feb 4, 2025
04d6a50
further debugging edits
AFg6K7h4fhy2 Feb 4, 2025
cdbe464
draws and iterations debugged
AFg6K7h4fhy2 Feb 4, 2025
987c273
revert versioning in test yaml
AFg6K7h4fhy2 Feb 4, 2025
8361c36
update location table; add united states data; update descriptions in…
AFg6K7h4fhy2 Feb 5, 2025
41a8136
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 Feb 5, 2025
bb55bf4
remove extraneous united states parquet call
AFg6K7h4fhy2 Feb 5, 2025
1cdef9d
Update forecasttools/idata_to_tidy.py
AFg6K7h4fhy2 Feb 6, 2025
519c048
Update forecasttools/idata_to_tidy.py
AFg6K7h4fhy2 Feb 6, 2025
1019ac0
fix docstring; fix chain equation
AFg6K7h4fhy2 Feb 6, 2025
92c34bd
revert ensure listlike import
AFg6K7h4fhy2 Feb 6, 2025
b3f6367
remove tab ignoral
AFg6K7h4fhy2 Feb 6, 2025
84bb99e
switch from melt to pivot
AFg6K7h4fhy2 Feb 6, 2025
1a94da4
lightweight change to dev deps
AFg6K7h4fhy2 Feb 10, 2025
7a737f6
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 Feb 10, 2025
740f9c8
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 Feb 11, 2025
b3544e5
commented change; was examining results
AFg6K7h4fhy2 Feb 14, 2025
3c9609c
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 Feb 14, 2025
ca7c340
have draw calculation take into account row count
AFg6K7h4fhy2 Mar 12, 2025
9b2cdc9
remove extraneous metaflow
AFg6K7h4fhy2 Mar 12, 2025
84ec6e1
some of the tests using simple idata class
AFg6K7h4fhy2 Mar 13, 2025
07e83a9
additional test
AFg6K7h4fhy2 Mar 13, 2025
238c80c
Update forecasttools/idata_to_tidy.py
AFg6K7h4fhy2 Mar 13, 2025
27436cd
revert to original aggregate function argument value
AFg6K7h4fhy2 Mar 13, 2025
70219ad
add base posterior predictive test for pyrenew idata
AFg6K7h4fhy2 Mar 14, 2025
60e0b5a
change test path; capture aggregate function error
AFg6K7h4fhy2 Mar 14, 2025
61849e0
change col search method from strip to re
AFg6K7h4fhy2 Mar 18, 2025
ba2847e
add lambda rename rather than dictionary comprehension
AFg6K7h4fhy2 Mar 18, 2025
15d920b
remove comment
AFg6K7h4fhy2 Mar 18, 2025
0604de1
update pre-commit config file to remove loop binding linting at reque…
AFg6K7h4fhy2 Mar 24, 2025
1 change: 1 addition & 0 deletions .gitignore
@@ -181,3 +181,4 @@ DS_Store
_book/
_book
render_test_idata_general_time_representation_files
*.csv
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
@@ -69,12 +69,12 @@ repos:
################################################################################
# PYTHON
################################################################################
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.10.0
hooks:
- id: black
args: ["--line-length", "79"]
language_version: python3
# - repo: https://github.com/psf/black-pre-commit-mirror
# rev: 24.10.0
# hooks:
# - id: black
# args: ["--line-length", "79"]
# language_version: python3
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
24 changes: 24 additions & 0 deletions assets/external/idata_csv_to_tidy.R
@@ -0,0 +1,24 @@
arviz_split <- function(x) {
x %>%
select(-distribution) %>%
split(f = as.factor(x$distribution))
}

pyrenew_samples <-
read_csv(inference_data_path) %>%
rename_with(\(varname) str_remove_all(varname, "\\(|\\)|\\'|(, \\d+)")) |>
rename(
.chain = chain,
.iteration = draw
) |>
mutate(across(c(.chain, .iteration), \(x) as.integer(x + 1))) |>
mutate(
.draw = tidybayes:::draw_from_chain_and_iteration_(.chain, .iteration),
.after = .iteration
) |>
pivot_longer(-starts_with("."),
names_sep = ", ",
names_to = c("distribution", "name")
) |>
arviz_split() |>
map(\(x) pivot_wider(x, names_from = name) |> tidy_draws())
Binary file added assets/external/inference_data_1.nc
85 changes: 85 additions & 0 deletions assets/external/inference_data_tidy.R
@@ -0,0 +1,85 @@
format_split_text <- function(x, concat_char = "|") {
group <- x[1]
non_group <- x[-1]
pre_bracket <- stringr::str_extract(non_group[1], "^.*(?=\\[)")
if (is.na(pre_bracket)) {
formatted_text <- glue::glue("{group}{concat_char}{non_group}")
} else {
bracket_contents <- non_group[-1] |>
stringr::str_replace_all("\\s", "_") |>
stringr::str_c(collapse = ",")
formatted_text <- glue::glue(
"{group}{concat_char}{pre_bracket}[{bracket_contents}]"
)
}
formatted_text
}

#' Convert InferenceData column names to tidy column names
#'
#' InferenceData column names for scalar variables are of the form
#' `"('group', 'var_name')"`, while column names for array variables are of the
#' form `"('group', 'var_name[i,j]', 'i_name', 'j_name')"`.
#' This function converts these column names to a format that is useful for
#' creating tidy_draws data frames.
#' `"('group', 'var_name')"` becomes `"group|var_name"`
#' `"('group', 'var_name[i,j]', 'i_name', 'j_name')"` becomes
#' `"group|var_name[i_name, j_name]"`
#'
#' @param column_names A character vector of InferenceData column names
#'
#' @return A character vector of tidy column names
#' @examples
#' forecasttools:::idata_names_to_tidy_names(c(
#' "('group', 'var_name')",
#' "group|var_name[i_name, j_name]"
#' ))
idata_names_to_tidy_names <- function(column_names) {
column_names |>
stringr::str_remove_all("^\\(|\\)$") |>
# remove opening and closing parentheses
stringr::str_split(", ") |>
purrr::map(\(x) stringr::str_remove_all(x, "^\\'|\\'$")) |>
# remove opening and closing quotes
purrr::map(\(x) stringr::str_remove_all(x, '\\"')) |> # remove double quotes
purrr::map_chr(format_split_text) # reformat groups and brackets
}

#' Convert InferenceData DataFrame to nested tibble of tidy_draws
#'
#' @param idata InferenceData DataFrame (the result of calling
#' arviz.InferenceData.to_dataframe in Python)
#'
#' @return A nested tibble, with columns group and data. Each element of data is
#' a tidy_draws data frame
#' @export

inferencedata_to_tidy_draws <- function(idata) {
idata |>
dplyr::rename(
.chain = chain,
.iteration = draw
) |>
dplyr::rename_with(idata_names_to_tidy_names,
.cols = -tidyselect::starts_with(".")
) |>
dplyr::mutate(dplyr::across(
c(.chain, .iteration),
\(x) as.integer(x + 1) # convert to 1-indexed
)) |>
dplyr::mutate(
.draw = tidybayes:::draw_from_chain_and_iteration_(.chain, .iteration),
.after = .iteration
) |>
tidyr::pivot_longer(-starts_with("."),
names_sep = "\\|",
names_to = c("group", "name")
) |>
dplyr::group_by(group) |>
tidyr::nest() |>
dplyr::mutate(data = purrr::map(data, \(x) {
tidyr::drop_na(x) |>
tidyr::pivot_wider(names_from = name) |>
tidybayes::tidy_draws()
}))
}
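The helper above expects its input to have the column-name format produced by `arviz.InferenceData.to_dataframe()` in Python, i.e. stringified tuples such as `"('posterior', 'alpha')"`. As a rough sketch (not part of this PR), the CSV consumed by these R scripts could be produced along the following lines; the output path is a placeholder:

```python
# Sketch: export an InferenceData object so its column names follow the
# "('group', 'var_name', ...)" convention parsed by
# idata_names_to_tidy_names() above. The output path is a placeholder.
import arviz as az

idata = az.from_netcdf("assets/external/inference_data_1.nc")

# to_dataframe() yields a pandas DataFrame with one row per (chain, draw),
# plain "chain" and "draw" columns, and tuple-valued column names such as
# ('posterior', 'alpha'); writing it to CSV stringifies those tuples into
# the quoted form described in the docstring above.
idata.to_dataframe().to_csv("inference_data.csv", index=False)
```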
17 changes: 17 additions & 0 deletions assets/misc/references.bib
@@ -0,0 +1,17 @@

# jax
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}

# numpyro
@article{phan2019composable,
title={Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro},
author={Phan, Du and Pradhan, Neeraj and Jankowiak, Martin},
journal={arXiv preprint arXiv:1912.11554},
year={2019}
}
2 changes: 2 additions & 0 deletions forecasttools/__init__.py
@@ -4,6 +4,7 @@
import polars as pl

from forecasttools.daily_to_epiweekly import df_aggregate_to_epiweekly
from forecasttools.idata_to_tidy import convert_idata_forecast_to_tidydraws
from forecasttools.idata_w_dates_to_df import (
add_time_coords_to_idata_dimension,
add_time_coords_to_idata_dimensions,
@@ -95,4 +96,5 @@
"generate_time_range_for_dim",
"validate_iter_has_expected_types",
"ensure_listlike",
"convert_idata_forecast_to_tidydraws"
]
71 changes: 71 additions & 0 deletions forecasttools/idata_to_tidy.py
@@ -0,0 +1,71 @@
"""
Contains functions for interfacing between
the tidy-verse and arviz, which includes
the conversion of idata objects (and hence
their groups) in tidy-usable objects.
"""

import re

import arviz as az
import polars as pl


def convert_idata_forecast_to_tidydraws(
idata: az.InferenceData,
groups: list[str]
) -> dict[str, pl.DataFrame]:
"""
Creates a dictionary of polars DataFrames
from the groups of an arviz InferenceData
object; each DataFrame, when written to csv
and read into R, is tidy-usable.

Parameters
----------
idata : az.InferenceData
An InferenceData object generated
from a numpyro forecast. Typically
has the groups observed_data and
posterior_predictive.
groups : list[str]
A list of groups belonging to the
idata object.

Returns
-------
dict[str, pl.DataFrame]
A dictionary mapping each requested group of
the idata to a tidy-usable polars DataFrame.
"""
tidy_dfs = {}
idata_df = idata.to_dataframe()
for group in groups:
group_columns = [
col for col in idata_df.columns
if isinstance(col, tuple) and col[0] == group
]
meta_columns = ["chain", "draw"]
group_df = idata_df[meta_columns + group_columns]
group_df.columns = [
col[1] if isinstance(col, tuple) else col
for col in group_df.columns
]
group_pols_df = pl.from_pandas(group_df)
value_columns = [col for col in group_pols_df.columns if col not in meta_columns]
group_pols_df = group_pols_df.melt(
id_vars=meta_columns,
value_vars=value_columns,
variable_name="variable",
value_name="value"
)
group_pols_df = group_pols_df.with_columns(
pl.col("variable").map_elements(
lambda x: re.sub(r"\[.*\]", "", x)).alias("variable")
)
group_pols_df = group_pols_df.with_columns(
((pl.col("draw") - 1) % group_pols_df["draw"].n_unique() + 1).alias(".iteration")
)
group_pols_df = group_pols_df.rename({"chain": ".chain", "draw": ".draw"})
tidy_dfs[group] = group_pols_df.select([".chain", ".draw", ".iteration", "variable", "value"])
return tidy_dfs
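A minimal usage sketch of the new function (not part of the diff), assuming the bundled `inference_data_1.nc` asset; the output file names are illustrative:

```python
# Sketch: convert a posterior_predictive group into a tidy-usable polars
# DataFrame and write it to CSV for downstream use with tidybayes in R.
# File paths here are illustrative, not part of this PR.
import arviz as az
import forecasttools

idata = az.from_netcdf("assets/external/inference_data_1.nc")
tidy_dfs = forecasttools.convert_idata_forecast_to_tidydraws(
    idata=idata,
    groups=["posterior_predictive"],
)
for group, df in tidy_dfs.items():
    # each value has .chain, .draw, .iteration, variable, and value columns
    df.write_csv(f"{group}.csv")
```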
156 changes: 156 additions & 0 deletions notebooks/pyrenew_dates_to_tidy.qmd
@@ -0,0 +1,156 @@
---
title: Add Dates To Pyrenew Idata And Use In Tidy-Verse
format: gfm
engine: jupyter
---

_The following notebook illustrates adding dates to an external `idata` object and then demonstrates using the resulting tidy-usable output in R._

__Load In Packages And External Pyrenew InferenceData Object__

```{python}
#| echo: true

import forecasttools
import arviz as az
import xarray as xr
import os
import subprocess
import tempfile
from datetime import date, timedelta

xr.set_options(display_expand_data=False, display_expand_attrs=False)


pyrenew_idata_path = "../assets/external/inference_data_1.nc"
pyrenew_idata = az.from_netcdf(pyrenew_idata_path)
pyrenew_idata
```

__Define Groups To Save And Convert__

```{python}
#| echo: true

pyrenew_groups = ["posterior_predictive"]
tidy_usable_groups = forecasttools.convert_idata_forecast_to_tidydraws(
idata=pyrenew_idata,
groups=pyrenew_groups
)

# show variables
print(tidy_usable_groups["posterior_predictive"]["variable"].unique())

# show output
tidy_usable_groups
```

__Demonstrate Adding Time To Pyrenew InferenceData__

```{python}
#| echo: true

start_time_as_dt = date(2022, 8, 1) # arbitrary

pyrenew_target_var = pyrenew_idata["posterior_predictive"]["observed_hospital_admissions"]
print(pyrenew_target_var)

pyrenew_var_w_dates = forecasttools.generate_time_range_for_dim(
start_time_as_dt=start_time_as_dt,
variable_data=pyrenew_target_var,
dimension="observed_hospital_admissions_dim_0",
time_step=timedelta(days=1),
)
print(pyrenew_var_w_dates[:5], type(pyrenew_var_w_dates[0]))
```

__Add Dates To Pyrenew InferenceData__

```{python}
#| echo: true

pyrenew_idata_w_dates = forecasttools.add_time_coords_to_idata_dimension(
idata=pyrenew_idata,
group="posterior_predictive",
variable="observed_hospital_admissions",
dimension="observed_hospital_admissions_dim_0",
start_date_iso=start_time_as_dt,
time_step=timedelta(days=1),
)

print(pyrenew_idata_w_dates["posterior_predictive"]["observed_hospital_admissions"]["observed_hospital_admissions_dim_0"])
pyrenew_idata_w_dates
```

__Again Convert The Dated Pyrenew InferenceData To Tidy-Usable__


```{python}

pyrenew_groups = ["posterior_predictive"]
tidy_usable_groups_w_dates = forecasttools.convert_idata_forecast_to_tidydraws(
idata=pyrenew_idata_w_dates,
groups=pyrenew_groups
)
tidy_usable_groups_w_dates
```

__Examine The Dataframe In The Tidyverse__

```{python}
def light_r_runner(r_code: str) -> None:
"""
Run R code from Python as a temp file.
"""
with tempfile.NamedTemporaryFile(suffix=".R", delete=False) as temp_r_file:
temp_r_file.write(r_code.encode("utf-8"))
temp_r_file_path = temp_r_file.name
try:
subprocess.run(["Rscript", temp_r_file_path], check=True)
except subprocess.CalledProcessError as e:
print(f"R script failed with error: {e}")
finally:
os.remove(temp_r_file_path)


for key, tidy_df in tidy_usable_groups_w_dates.items():
file_name = f"{key}.csv"
if not os.path.exists(file_name):
tidy_df.write_csv(file_name)
print(f"Saved {file_name}")


r_code_to_verify_tibble = """
library(magrittr)
library(tidyverse)
library(tidybayes)

csv_files <- c("posterior_predictive.csv")

for (csv_file in csv_files) {
tibble_data <- read_csv(csv_file)

print(paste("Tibble from", csv_file))
print(tibble_data)

tidy_data <- tibble_data %>%
tidybayes::tidy_draws()
print(tidy_data)
}
"""
light_r_runner(r_code_to_verify_tibble)
```

The output of the last cell is:

```
ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
Error: To use a data frame directly with `tidy_draws()`, it must already be a
tidy-format data frame of draws: it must have integer-like `.chain`
`.iteration`, and `.draw` columns with one row per draw.

The `.draw` column in the input data frame has more than one row per draw
(its values are not unique).
Execution halted
```
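The failure is expected given the current conversion: the long-format output carries one row per (chain, iteration, variable), so the renamed `draw` column repeats and `tidy_draws()` rejects it, since `.draw` must uniquely index each (chain, iteration) pair (uniqueness additionally requires one row per draw, i.e. pivoting the variables wide, as the R helpers do). The tidybayes convention, implemented by `draw_from_chain_and_iteration_()` in the R helpers above, appears to be `.draw = (.chain - 1) * n_iterations + .iteration`. The following is only a hedged polars sketch of recomputing such a column, assuming `.chain` and `.iteration` are already 1-indexed:

```python
# Sketch only: recompute a unique .draw index from .chain and .iteration,
# assuming both are already 1-indexed. Intended to mirror the tidybayes
# chain-and-iteration convention referenced in the R helpers above.
import polars as pl


def add_unique_draw_index(df: pl.DataFrame) -> pl.DataFrame:
    # number of iterations per chain, inferred from the data
    n_iterations = df[".iteration"].n_unique()
    return df.with_columns(
        ((pl.col(".chain") - 1) * n_iterations + pl.col(".iteration")).alias(".draw")
    )
```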