Function To Cast InferenceData Into tidy_draws Format #36
…nferencedata-into-tidy_draws-format
A small note that the following historically worked but assumes all chains have the same number of iterations (in tidy-lingo):

    tidy_dfs = {
        group: (
            idata_df.select("chain", "draw", cs.starts_with(f"('{group}',"))
            .rename(
                {
                    col: col.split(", ")[1].strip("')")
                    for col in idata_df.columns
                    if col.startswith(f"('{group}',")
                }
            )
            # draw in arviz is iteration in tidybayes
            .rename({"draw": ".iteration", "chain": ".chain"})
            .unpivot(
                index=[".chain", ".iteration"],
                variable_name="variable",
                value_name="value",
            )
            .with_columns(
                pl.col("variable").str.replace(r"\[.*\]", "").alias("variable")
            )
            .with_columns(pl.col(".iteration") + 1, pl.col(".chain") + 1)
            .with_columns(
                (pl.col(".iteration").n_unique()).alias("draws_per_chain"),
            )
            .with_columns(
                (
                    ((pl.col(".chain") - 1) * pl.col("draws_per_chain"))
                    + pl.col(".iteration")
                ).alias(".draw")
            )
            .pivot(
                values="value",
                index=[".chain", ".iteration", ".draw"],
                columns="variable",
                aggregate_function="first",
            )
            .sort([".chain", ".iteration", ".draw"])
        )
        for group in groups
    }

The method which does take into account the number of iterations per chain:

    tidy_dfs = {
        group: (
            idata_df.select("chain", "draw", cs.starts_with(f"('{group}',"))
            .rename(
                {
                    col: col.split(", ")[1].strip("')")
                    for col in idata_df.columns
                    if col.startswith(f"('{group}',")
                }
            )
            # draw in arviz is iteration in tidybayes
            .rename({"draw": ".iteration", "chain": ".chain"})
            .unpivot(
                index=[".chain", ".iteration"],
                variable_name="variable",
                value_name="value",
            )
            .with_columns(
                pl.col("variable").str.replace(r"\[.*\]", "").alias("variable")
            )
            .with_columns(
                pl.col(".iteration") + 1,
                pl.col(".chain") + 1,
            )
            .pivot(
                values="value",
                index=[".chain", ".iteration"],
                columns="variable",
                aggregate_function="first",
            )
            # .draw does not exist yet at this point, so sort on chain and
            # iteration only, then number the sorted rows to create .draw
            .sort([".chain", ".iteration"])
            .with_row_count(name=".draw", offset=1)
        )
        for group in groups
    }
A few small but important things. Thanks, @AFg6K7h4fhy2!
forecasttools/idata_to_tidy.py (outdated)

        values="value",
        index=[".chain", ".iteration"],
        columns="variable",
        aggregate_function="first",
Should be `None` per the docs, no? (Unless I'm misunderstanding what your goal is with this operation.) You want one column for each unique value of "variable" for a given ".chain" and ".iteration".

Suggested change:

    - aggregate_function="first",
    + aggregate_function=None,

https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.pivot.html
For the actual pyrenew inference data object, but not for the other examples I have made, using `None` and not `first` produces this error (I originally used `None` before making a notebook & tests with the pyrenew-hew InferenceData object given to me):

    FAILED tests/test_idata_to_tidy.py::test_posterior_predictive_group - polars.exceptions.ComputeError: found multiple elements in the same group, please specify an aggregation function
That suggests something is not correct upstream. Perhaps the variable name regex?
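One way this error can arise, offered as an untested guess rather than a diagnosis of the pyrenew-hew data: stripping the bracketed index with `\[.*\]` gives every element of an array-valued parameter the same variable name, so each (chain, iteration) pair ends up with several values for that variable. The sketch below uses only the standard library; the rows, the variable names (`alpha`, `beta[0]`, `beta[1]`), and the helper names are all made up for illustration.

```python
import re
from collections import Counter

# Hypothetical long-format rows: (chain, iteration, variable, value).
rows = [
    (1, 1, "alpha", 0.1),
    (1, 1, "beta[0]", 0.2),
    (1, 1, "beta[1]", 0.3),
    (1, 2, "alpha", 0.4),
    (1, 2, "beta[0]", 0.5),
    (1, 2, "beta[1]", 0.6),
]


def strip_index(variable: str) -> str:
    """Drop a bracketed index such as '[0]' from a variable name."""
    return re.sub(r"\[.*\]", "", variable, count=1)


def pivot_collisions(rows):
    """Return the (chain, iteration, variable) keys that occur more than once."""
    counts = Counter(
        (chain, iteration, strip_index(variable))
        for chain, iteration, variable, _ in rows
    )
    return {key: n for key, n in counts.items() if n > 1}


collisions = pivot_collisions(rows)
print(collisions)
```

With these rows, both `beta` elements collapse onto the same key for each iteration. A pivot with no aggregation function has nothing to choose between for such keys, whereas `first` would keep one value per key and drop the others without complaint.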
I believe the source of this issue is the line `col: col.split(", ")[1].strip("')")` in:

    .rename(
        {
            col: col.split(", ")[1].strip("')")
            for col in idata_df.columns
            if col.startswith(f"('{group}',")
        }
    )
I changed

    col: col.split(", ")[1].strip("')")

to

    col: re.search(r",\s*'?(.+?)'?\)", col).group(1)

which I got incorrect a handful of times but which I believe works now.

The target of this regex looks like, e.g., `"('posterior', 'alpha')"`.

In the expression, the `,` matches a comma; in `\s*`, the `\s` matches a whitespace character and the `*` means zero or more of them; in `'?`, the `'` is a single quote and the `?` makes it optional (0 or 1 times); in `(.+?)`, the `(` and `)` capture whatever matches between them, where `.` is any non-newline character, `+` means one or more characters, and `?` (which I am not sure is necessary to include here) makes the match lazy, i.e. as short as possible (this capture will get `alpha` in the example target); the second `'?` is the same as before; the `\)` matches the single closing parenthesis after the group (the `\` is needed to escape the parenthesis); and `.group(1)` in `re` returns the text captured by `(.+?)`.

If I am missing anything or wrote something inaccurately, please, reader, let me know.

The aggregate function is now successfully set to `None` and not `first`.
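To make the difference between the two extraction approaches concrete, here is a minimal standard-library sketch. The helper names `name_via_split` and `name_via_regex` are illustrative, and the second column name, with an extra `", "` inside brackets, is hypothetical: nothing in this thread establishes that the pyrenew-hew object contains names of that shape.

```python
import re


def name_via_split(col: str) -> str:
    """Variable name via the original split-and-strip approach."""
    return col.split(", ")[1].strip("')")


def name_via_regex(col: str) -> str:
    """Variable name via the regex from this comment."""
    match = re.search(r",\s*'?(.+?)'?\)", col)
    if match is None:
        raise ValueError(f"unexpected column name: {col!r}")
    return match.group(1)


simple = "('posterior', 'alpha')"
# Hypothetical name with an extra ", " inside the brackets.
indexed = "('posterior', 'beta[0, 1]')"

for col in (simple, indexed):
    print(col, "->", name_via_split(col), "|", name_via_regex(col))
```

Both helpers agree on the simple name. On the hypothetical indexed name, the split approach stops at the second `", "` and returns a truncated `beta[0`, while the regex keeps reading until the closing quote and parenthesis.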
Can you write a test to demonstrate the case where the previous version was failing and the new version succeeds? They both work correctly on the example provided in your comment: `"('posterior', 'alpha')"`.
I do understand the subtle difference between the two approaches, but I do not understand why the split approach did not work in practice.
The split approach did not work on the InferenceData object from pyrenew-hew in tests when the aggregate function was set to `None` but did work when the aggregate function was set to `first`. Yes, both the split and `re` approaches work on `"('posterior', 'alpha')"`.
Thank you for the comment; I will write a test.
> The split approach did not work on the InferenceData object from pyrenew-hew in tests when the aggregate function was set to None but did work when the aggregate function was set to first. Yes, both the split and re approaches work on "('posterior', 'alpha')".

I should probably investigate why, exactly, `aggregate_function=None` with split doesn't work for the pyrenew inference data.
        assert set(result.keys()) == set(simple_inference_data.groups())


    def test_tidydraws_format(simple_inference_data):
This is nice, but I think it's also important to check the correct tidying of array-valued parameters. The example idata and associated test from forecasttools-R are nice and should be easily portable to Python:
https://github.com/CDCgov/forecasttools/blob/main/data-raw/ex_inferencedata_dataframe.R
https://github.com/CDCgov/forecasttools/blob/main/tests/testthat/test_inferencedata_dataframe_to_tidydraws.R
I will save this task for 2025-03-14.
Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>
forecasttools/idata_to_tidy.py (outdated)

    @@ -54,7 +56,7 @@ def convert_inference_data_to_tidydraws(
                 idata_df.select("chain", "draw", cs.starts_with(f"('{group}',"))
                 .rename(
                     {
    -                    col: col.split(", ")[1].strip("')")
    +                    col: re.search(r",\s*'?(.+?)'?\)", col).group(1)
Feels like this would be better handled by providing a lambda rename mapping to `.rename()` rather than via dictionary comprehension. See https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.rename.html
Hadn't thought of this. In the back of my mind there are reservations I've seen from others about lambda mappings, which I haven't verified, but they're not sufficient to not act here. Polars `rename` seems apt. Thank you for the quick comment.
I changed the code to:

    .rename(
        lambda col: re.search(r",\s*'?(.+?)'?\)", col).group(1)
        if col.startswith(f"('{group}',")
        else col
    )

got

    forecasttools/idata_to_tidy.py:66:40: B023 Function definition does not bind loop variable `group`

then changed to:

    .rename(
        lambda col, group=group: re.search(
            r",\s*'?(.+?)'?\)", col
        ).group(1)
        if col.startswith(f"('{group}',")
        else col
    )

which works w/ tests & linting.

Commenting above for my future self.
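For context on what B023 guards against, here is a small standard-library sketch. All function names are illustrative and unrelated to forecasttools. A lambda defined in a loop looks up the loop variable when it is called, not when it is defined, so closures stored for later all see the final value. Binding the value as a default argument avoids that, and so does calling the lambda before the loop advances, which is presumably what happens when it is passed straight to `.rename()`.

```python
def make_renamers_late(groups):
    """Build one renamer per group; each closure looks up `group` when called."""
    renamers = []
    for group in groups:
        renamers.append(lambda col: f"{group}:{col}")  # noqa: B023
    return renamers


def make_renamers_bound(groups):
    """Same, but freeze the current value of `group` as a default argument."""
    renamers = []
    for group in groups:
        renamers.append(lambda col, group=group: f"{group}:{col}")
    return renamers


def rename_immediately(groups):
    """Call the lambda inside the loop body, before `group` changes."""
    results = []
    for group in groups:
        rename = lambda col: f"{group}:{col}"  # noqa: B023, E731
        results.append(rename("alpha"))
    return results


groups = ["posterior", "prior"]
late = [f("alpha") for f in make_renamers_late(groups)]
bound = [f("alpha") for f in make_renamers_bound(groups)]
immediate = rename_immediately(groups)
print(late)
print(bound)
print(immediate)
```

Only the stored, late-called closures go wrong here: both report the last group. The default-argument form and the immediate call each yield one result per group, which is why either silencing the warning or binding `group=group` can be defended when the lambda is consumed on the spot.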
- Readability concern: at least two different senses of "group" present here.
- I'd make it a one-argument `lambda` and exclude it from linting. Making "group" an argument that defaults to the value of `group` satisfies the linter but hurts readability imo.

    forecasttools/idata_to_tidy.py:66:40: B023 Function definition does not bind loop variable `group`
Re: #36 (comment)

Will do. I agree, `group=group` seems extraneous, but my usual response is to defer to the linter.

For the scope of this PR, please refer to issue #18.