-
Notifications
You must be signed in to change notification settings - Fork 2
Function To Cast InferenceData Into tidy_draws
Format
#36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 85 commits
3865001
763355e
44e7fe2
31c7b72
9a87902
c632ae8
123ad51
a3c2d17
f44a6ee
cb883e3
df922d4
21968be
7dcd7d3
718ba85
7394d4d
4a77d50
4c18634
9be91b4
0884c3d
0dd9616
2a0f2b9
c746e3b
e4e8e0f
c5ed832
e9e2aab
9a27187
cc20975
4a8a0ae
0c4301a
814f68c
beba8f3
b12ffc2
d85f49c
85c0dbf
786d2e4
4c4e6dd
e949796
ab0beb6
27e7ada
b06f6e6
42acd05
2ed4c29
0ad7c24
9452acd
325aac4
462d57b
dae049f
5e4cd63
5c28831
d246a44
496c65c
2c34af7
dba6e7a
e1cdbb9
9884aa9
90f5484
683a7bf
3547ab6
2999ec3
04d6a50
cdbe464
987c273
8361c36
41a8136
bb55bf4
1cdef9d
519c048
1019ac0
92c34bd
b3f6367
84bb99e
1a94da4
7a737f6
740f9c8
b3544e5
3c9609c
ca7c340
9b2cdc9
84ec6e1
07e83a9
238c80c
27436cd
70219ad
60e0b5a
61849e0
ba2847e
15d920b
0604de1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
AFg6K7h4fhy2 marked this conversation as resolved.
Show resolved
Hide resolved
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
|
||
# jax | ||
@software{jax2018github, | ||
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang}, | ||
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs}, | ||
url = {http://github.com/jax-ml/jax}, | ||
version = {0.3.13}, | ||
year = {2018}, | ||
} | ||
|
||
# numpyro | ||
@article{phan2019composable, | ||
title={Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro}, | ||
author={Phan, Du and Pradhan, Neeraj and Jankowiak, Martin}, | ||
journal={arXiv preprint arXiv:1912.11554}, | ||
year={2019} | ||
} |
AFg6K7h4fhy2 marked this conversation as resolved.
Show resolved
Hide resolved
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
""" | ||
Contains functions for interfacing between | ||
the tidy-verse and arviz, which includes | ||
the conversion of idata objects (and hence | ||
their groups) in tidy-usable objects. | ||
""" | ||
|
||
import re | ||
|
||
import arviz as az | ||
import polars as pl | ||
import polars.selectors as cs | ||
|
||
|
||
def convert_inference_data_to_tidydraws( | ||
idata: az.InferenceData, groups: list[str] = None | ||
) -> dict[str, pl.DataFrame]: | ||
""" | ||
Creates a dictionary of polars dataframes | ||
from the groups of an arviz InferenceData | ||
object for use with the tidybayes API. | ||
|
||
Parameters | ||
---------- | ||
idata : az.InferenceData | ||
An InferenceData object. | ||
groups : list[str] | ||
A list of groups to transform to | ||
tidy draws format. Defaults to all | ||
groups in the InferenceData. | ||
|
||
Returns | ||
------- | ||
dict[str, pl.DataFrame] | ||
A dictionary of groups from the idata | ||
for use with the tidybayes API. | ||
""" | ||
available_groups = list(idata.groups()) | ||
if groups is None: | ||
groups = available_groups | ||
else: | ||
invalid_groups = [ | ||
group for group in groups if group not in available_groups | ||
] | ||
if invalid_groups: | ||
raise ValueError( | ||
f"Requested groups {invalid_groups} not found" | ||
" in this InferenceData object." | ||
f" Available groups: {available_groups}" | ||
) | ||
|
||
idata_df = pl.DataFrame(idata.to_dataframe()) | ||
|
||
tidy_dfs = { | ||
group: ( | ||
idata_df.select("chain", "draw", cs.starts_with(f"('{group}',")) | ||
.rename( | ||
{ | ||
col: re.search(r",\s*'?(.+?)'?\)", col).group(1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Feels like this would be better handled by providing a lambda rename mapping to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hadn't thought of this. In the back of mind there are reservations I've seen from others that I haven't verified regarding about lambda mappings but they're not sufficient to not act here. Polars There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I changed the code to: .rename(
lambda col: re.search(r",\s*'?(.+?)'?\)", col).group(1)
if col.startswith(f"('{group}',")
else col
) got
then changed to: .rename(
lambda col, group=group: re.search(
r",\s*'?(.+?)'?\)", col
).group(1)
if col.startswith(f"('{group}',")
else col
) which works w/ tests & linting. Commenting above for my future self. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Re: #36 (comment) Will do. I agree, |
||
for col in idata_df.columns | ||
if col.startswith(f"('{group}',") | ||
} | ||
) | ||
# draw in arviz is iteration in tidybayes | ||
.rename({"draw": ".iteration", "chain": ".chain"}) | ||
.unpivot( | ||
index=[".chain", ".iteration"], | ||
variable_name="variable", | ||
value_name="value", | ||
) | ||
.with_columns( | ||
pl.col("variable").str.replace(r"\[.*\]", "").alias("variable") | ||
) | ||
.with_columns(pl.col(".iteration") + 1, pl.col(".chain") + 1) | ||
.pivot( | ||
values="value", | ||
index=[".chain", ".iteration"], | ||
columns="variable", | ||
aggregate_function=None, | ||
) | ||
.sort([".chain", ".iteration"]) | ||
.with_row_index(name=".draw", offset=1) | ||
) | ||
for group in groups | ||
} | ||
return tidy_dfs |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
--- | ||
title: Add Dates To Pyrenew Idata And Use In Tidy-Verse | ||
format: gfm | ||
engine: jupyter | ||
--- | ||
|
||
_The following notebook illustrates the addition of dates to an external `idata` object before demonstrating the tidy-usable capabilities in R._ | ||
|
||
__Load In Packages And External Pyrenew InferenceData Object__ | ||
|
||
```{python} | ||
#| echo: true | ||
|
||
import forecasttools | ||
import arviz as az | ||
import xarray as xr | ||
import os | ||
import polars as pl | ||
import subprocess | ||
import tempfile | ||
from datetime import date, timedelta | ||
|
||
xr.set_options(display_expand_data=False, display_expand_attrs=False) | ||
|
||
|
||
pyrenew_idata_path = "../assets/external/inference_data_1.nc" | ||
pyrenew_idata = az.from_netcdf(pyrenew_idata_path) | ||
pyrenew_idata | ||
``` | ||
|
||
__Define Groups To Save And Convert__ | ||
|
||
|
||
```{python} | ||
#| echo: true | ||
|
||
pyrenew_groups = ["posterior_predictive"] | ||
tidy_usable_groups = forecasttools.convert_inference_data_to_tidydraws( | ||
idata=pyrenew_idata, | ||
groups=pyrenew_groups | ||
) | ||
|
||
# show output | ||
tidy_usable_groups | ||
``` | ||
|
||
```{python} | ||
# TO DELETE | ||
nested_tidy_dfs = { | ||
group: { | ||
var: df.select([".chain", ".iteration", ".draw", var]) | ||
for var in [col for col in df.columns if col not in [".chain", ".iteration", ".draw"]] | ||
} | ||
for group, df in tidy_usable_groups.items() | ||
} | ||
|
||
nested_tidy_dfs | ||
``` | ||
|
||
__Demonstrate Adding Time To Pyrenew InferenceData__ | ||
|
||
```{python} | ||
#| echo: true | ||
|
||
start_time_as_dt = date(2022, 8, 1) # arbitrary | ||
|
||
pyrenew_target_var = pyrenew_idata["posterior_predictive"]["observed_hospital_admissions"] | ||
print(pyrenew_target_var) | ||
|
||
pyrenew_var_w_dates = forecasttools.generate_time_range_for_dim( | ||
start_time_as_dt=start_time_as_dt, | ||
variable_data=pyrenew_target_var, | ||
dimension="observed_hospital_admissions_dim_0", | ||
time_step=timedelta(days=1), | ||
) | ||
print(pyrenew_var_w_dates[:5], type(pyrenew_var_w_dates[0])) | ||
``` | ||
|
||
__Add Dates To Pyrenew InferenceData__ | ||
|
||
```{python} | ||
#| echo: true | ||
|
||
pyrenew_idata_w_dates = forecasttools.add_time_coords_to_idata_dimension( | ||
idata=pyrenew_idata, | ||
group="posterior_predictive", | ||
variable="observed_hospital_admissions", | ||
dimension="observed_hospital_admissions_dim_0", | ||
start_date_iso=start_time_as_dt, | ||
time_step=timedelta(days=1), | ||
) | ||
|
||
print(pyrenew_idata_w_dates["posterior_predictive"]["observed_hospital_admissions"]["observed_hospital_admissions_dim_0"]) | ||
pyrenew_idata_w_dates | ||
``` | ||
|
||
__Again Convert The Dated Pyrenew InferenceData To Tidy-Usable__ | ||
|
||
|
||
```{python} | ||
|
||
pyrenew_groups = ["posterior_predictive"] | ||
tidy_usable_groups_w_dates = forecasttools.convert_inference_data_to_tidydraws( | ||
idata=pyrenew_idata_w_dates, | ||
groups=pyrenew_groups | ||
) | ||
tidy_usable_groups_w_dates | ||
``` | ||
|
||
__Examine The Dataframe In The Tidyverse__ | ||
|
||
```{python} | ||
def light_r_runner(r_code: str) -> None: | ||
""" | ||
Run R code from Python as a temp file. | ||
""" | ||
with tempfile.NamedTemporaryFile(suffix=".R", delete=False) as temp_r_file: | ||
temp_r_file.write(r_code.encode("utf-8")) | ||
temp_r_file_path = temp_r_file.name | ||
try: | ||
subprocess.run(["Rscript", temp_r_file_path], check=True) | ||
except subprocess.CalledProcessError as e: | ||
print(f"R script failed with error: {e}") | ||
finally: | ||
os.remove(temp_r_file_path) | ||
|
||
#for key, tidy_df in tidy_usable_groups_w_dates.items(): | ||
# file_name = f"{key}.csv" | ||
# if not os.path.exists(file_name): | ||
# tidy_df.write_csv(file_name) | ||
# print(f"Saved {file_name}") | ||
|
||
for group, var in nested_tidy_dfs.items(): | ||
for var_name, tidy_data in var.items(): | ||
file_name = f"{var_name}.csv" | ||
#if not os.path.exists(file_name): | ||
tidy_data.write_csv(file_name) | ||
print(f"Saved {file_name}") | ||
|
||
|
||
r_code_to_verify_tibble = """ | ||
library(magrittr) | ||
library(tidyverse) | ||
library(tidybayes) | ||
|
||
csv_files <- c("rt.csv") | ||
|
||
for (csv_file in csv_files) { | ||
tibble_data <- read_csv(csv_file) | ||
|
||
print(paste("Tibble from", csv_file)) | ||
print(tibble_data) | ||
|
||
tidy_data <- tibble_data %>% | ||
tidybayes::tidy_draws() | ||
print(tidy_data) | ||
} | ||
""" | ||
light_r_runner(r_code_to_verify_tibble) | ||
``` |
Uh oh!
There was an error while loading. Please reload this page.