Function To Cast InferenceData Into tidy_draws Format #36

Open: wants to merge 88 commits into main

AFg6K7h4fhy2 (Collaborator) commented Oct 28, 2024

For the scope of this PR, please refer to issue #18.

AFg6K7h4fhy2 self-assigned this Oct 28, 2024
AFg6K7h4fhy2 linked an issue Oct 28, 2024 that may be closed by this pull request
AFg6K7h4fhy2 added the labels feature (A new tool or utility being added.) and High Priority (A task that is of higher relative priority.) Oct 28, 2024
AFg6K7h4fhy2 added this to the [October 28, November 8] milestone Oct 28, 2024
AFg6K7h4fhy2 added the Medium Priority label (A task that is of medium relative priority.) and removed the High Priority label Oct 28, 2024
AFg6K7h4fhy2 mentioned this pull request Nov 6, 2024
AFg6K7h4fhy2 (Collaborator Author)

A small note: the following historically worked, but it assumes that all chains have the same number of iterations (in tidybayes terminology):

import polars as pl
import polars.selectors as cs

# Assumes idata_df (a polars DataFrame with "chain", "draw", and
# "('<group>', '<variable>')" columns) and groups are already defined.
tidy_dfs = {
    group: (
        idata_df.select("chain", "draw", cs.starts_with(f"('{group}',"))
        .rename(
            {
                col: col.split(", ")[1].strip("')")
                for col in idata_df.columns
                if col.startswith(f"('{group}',")
            }
        )
        # draw in arviz is iteration in tidybayes
        .rename({"draw": ".iteration", "chain": ".chain"})
        .unpivot(
            index=[".chain", ".iteration"],
            variable_name="variable",
            value_name="value",
        )
        .with_columns(
            pl.col("variable").str.replace(r"\[.*\]", "").alias("variable")
        )
        .with_columns(pl.col(".iteration") + 1, pl.col(".chain") + 1)
        # n_unique() counts distinct iterations across ALL chains, so this is
        # the per-chain draw count only if every chain has the same length
        .with_columns(
            (pl.col(".iteration").n_unique()).alias("draws_per_chain"),
        )
        .with_columns(
            (
                ((pl.col(".chain") - 1) * pl.col("draws_per_chain"))
                + pl.col(".iteration")
            ).alias(".draw")
        )
        .pivot(
            values="value",
            index=[".chain", ".iteration", ".draw"],
            columns="variable",
            aggregate_function="first",
        )
        .sort([".chain", ".iteration", ".draw"])
    )
    for group in groups
}

The following method does take into account the number of iterations per chain:

tidy_dfs = {
    group: (
        idata_df.select("chain", "draw", cs.starts_with(f"('{group}',"))
        .rename(
            {
                col: col.split(", ")[1].strip("')")
                for col in idata_df.columns
                if col.startswith(f"('{group}',")
            }
        )
        # draw in arviz is iteration in tidybayes
        .rename({"draw": ".iteration", "chain": ".chain"})
        .unpivot(
            index=[".chain", ".iteration"],
            variable_name="variable",
            value_name="value",
        )
        .with_columns(
            pl.col("variable").str.replace(r"\[.*\]", "").alias("variable")
        )
        .with_columns(pl.col(".iteration") + 1, pl.col(".chain") + 1)
        .pivot(
            values="value",
            index=[".chain", ".iteration"],
            columns="variable",
            aggregate_function="first",
        )
        # ".draw" does not exist yet, so sort on chain and iteration only
        .sort([".chain", ".iteration"])
        # number rows 1..N across all chains, whatever each chain's length;
        # with_row_index replaces the deprecated with_row_count
        .with_row_index(name=".draw", offset=1)
    )
    for group in groups
}
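
To see how the two .draw numberings differ, here is a small pure-Python sketch (no polars; the helper names draw_ids_formula and draw_ids_row_index and the chain lengths are hypothetical, for illustration only). The first helper applies the first method's formula, (.chain - 1) * draws_per_chain + .iteration, with draws_per_chain taken as the number of distinct iterations across all chains; the second numbers the sorted (chain, iteration) rows 1..N, as the row-index step in the second method does.

```python
def draw_ids_formula(chain_lengths):
    """First method: (.chain - 1) * draws_per_chain + .iteration.

    draws_per_chain is the number of distinct .iteration values across ALL
    chains, which is the length of the longest chain.
    """
    draws_per_chain = max(chain_lengths)
    return [
        (chain - 1) * draws_per_chain + iteration
        for chain, n in enumerate(chain_lengths, start=1)
        for iteration in range(1, n + 1)
    ]


def draw_ids_row_index(chain_lengths):
    """Second method: sort by (chain, iteration), then number rows 1..N."""
    rows = sorted(
        (chain, iteration)
        for chain, n in enumerate(chain_lengths, start=1)
        for iteration in range(1, n + 1)
    )
    return [draw for draw, _ in enumerate(rows, start=1)]


# Hypothetical chains with 2, 3, and 3 retained draws.
lengths = [2, 3, 3]
print(draw_ids_formula(lengths))    # [1, 2, 4, 5, 6, 7, 8, 9]
print(draw_ids_row_index(lengths))  # [1, 2, 3, 4, 5, 6, 7, 8]
```

With equal chain lengths the two agree. With unequal lengths, the formula leaves gaps (no draw 3 here, and the largest .draw exceeds the number of draws), while the row index stays contiguous.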

dylanhmorris (Collaborator) left a comment

A few small but important things. Thanks, @AFg6K7h4fhy2!

values="value",
index=[".chain", ".iteration"],
columns="variable",
aggregate_function="first",

Collaborator

aggregate_function should be None per the docs, no? (Unless I'm misunderstanding your goal with this operation.) You want one column for each unique value of "variable", with one row per ".chain" and ".iteration".

Suggested change
- aggregate_function="first",
+ aggregate_function=None,

https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.pivot.html

AFg6K7h4fhy2 (Collaborator Author), Mar 14, 2025

With the actual pyrenew-hew InferenceData object (but not with the other examples I have made), using None instead of "first" produces the error below. (I originally used None, before building a notebook and tests around the pyrenew-hew InferenceData object I was given.)

FAILED tests/test_idata_to_tidy.py::test_posterior_predictive_group - polars.exceptions.ComputeError: found multiple elements in the same group, please specify an aggregation function

Collaborator

That suggests something is not correct upstream. Perhaps the variable name regex?
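
One way the "found multiple elements in the same group" error can arise (a hedged guess, not confirmed against the pyrenew-hew object): the str.replace(r"\[.*\]", "") step drops the index from array-valued names, so "x[0]" and "x[1]" both become "x" and share a (.chain, .iteration, variable) key. DataFrame.pivot with aggregate_function=None rejects such duplicate keys, while "first" silently keeps one value and drops the rest. The pure-Python sketch below, with made-up rows and hypothetical helpers pivot_keys and duplicate_keys, counts the colliding keys.

```python
import re
from collections import Counter

# Hypothetical long-format rows: (.chain, .iteration, variable, value).
rows = [
    (1, 1, "alpha", 0.5),
    (1, 1, "x[0]", 1.0),
    (1, 1, "x[1]", 2.0),
]


def pivot_keys(rows, strip_index):
    """Return the (.chain, .iteration, variable) keys a pivot would group on."""
    keys = []
    for chain, iteration, variable, _value in rows:
        if strip_index:
            # Same pattern as the polars str.replace(r"\[.*\]", "") step.
            variable = re.sub(r"\[.*\]", "", variable)
        keys.append((chain, iteration, variable))
    return keys


def duplicate_keys(keys):
    """Keys that occur more than once; a pivot without aggregation rejects these."""
    return [key for key, count in Counter(keys).items() if count > 1]


print(duplicate_keys(pivot_keys(rows, strip_index=False)))  # []
print(duplicate_keys(pivot_keys(rows, strip_index=True)))   # [(1, 1, 'x')]
```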

Collaborator Author

I believe the source of this issue is the line col: col.split(", ")[1].strip("')") in the following rename:

.rename(
    {
        col: col.split(", ")[1].strip("')")
        for col in idata_df.columns
        if col.startswith(f"('{group}',")
    }
)

AFg6K7h4fhy2 (Collaborator Author), Mar 18, 2025

I changed

col: col.split(", ")[1].strip("')")

to

col: re.search(r",\s*'?(.+?)'?\)", col).group(1)

which took me a handful of attempts to get right, but which I believe works now.

The target of this regex looks like, e.g., "('posterior', 'alpha')".

In the expression:

  1. , matches the comma.
  2. \s* matches zero or more whitespace characters (\s is any whitespace character; * means zero or more).
  3. '? matches an optional single quote (? means zero or one time).
  4. (.+?) captures the variable name: . matches any non-newline character, + means one or more, and the trailing ? makes the + lazy, so the capture is as short as possible. The lazy ? is needed: with a greedy .+, the optional '? that follows can match nothing, so the capture would keep the closing quote (alpha' instead of alpha). In the example target, the capture is alpha.
  5. '? is the same optional single quote as before.
  6. \) matches the closing parenthesis after the captured name (the backslash escapes the parenthesis, which would otherwise close a group).
  7. .group(1) returns the text captured by (.+?), the first capturing group.
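
A quick check, using Python's standard re module, of whether the lazy ? in (.+?) is needed. GREEDY is the same pattern with that ? removed, included only for comparison:

```python
import re

LAZY = r",\s*'?(.+?)'?\)"   # pattern used in the PR
GREEDY = r",\s*'?(.+)'?\)"  # same pattern without the lazy "?" (comparison only)

target = "('posterior', 'alpha')"

print(re.search(LAZY, target).group(1))    # alpha
print(re.search(GREEDY, target).group(1))  # alpha'
```

Both patterns match, but the greedy capture runs to the end of the string and backs off only as far as the closing parenthesis, so the optional quote before \) matches nothing and the quote ends up inside the capture.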

If I am missing anything or wrote something inaccurately, please, reader, let me know.

The aggregate function is now successfully set to None rather than "first".

damonbayer (Contributor), Mar 18, 2025

Can you write a test to demonstrate the case where the previous version was failing and the new version succeeds? They both work correctly on the example provided in your comment: "('posterior', 'alpha')".

I do understand the subtle difference between the two approaches, but I do not understand why the split approach did not work in practice.
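
A minimal sketch of the kind of test requested here. The helpers name_by_split and name_by_regex wrap the two approaches and are hypothetical names. The column labels are also hypothetical: they are built with str() on a tuple (which is how "('posterior', 'alpha')" is spelled), and whether arviz's to_dataframe() ever emits labels like the last two is not established in this thread. They do show that the two approaches are not equivalent: the split version cuts a variable name that itself contains ", ", while the regex version runs past the second element of a three-element label.

```python
import re


def name_by_split(col):
    """Previous approach: second ", "-separated piece, quotes/parens stripped."""
    return col.split(", ")[1].strip("')")


def name_by_regex(col):
    """New approach: lazy capture up to the first (optionally quoted) ")"."""
    return re.search(r",\s*'?(.+?)'?\)", col).group(1)


def test_simple_name():
    col = str(("posterior", "alpha"))  # "('posterior', 'alpha')"
    assert name_by_split(col) == name_by_regex(col) == "alpha"


def test_name_containing_comma_space():
    # Hypothetical label whose variable name itself contains ", ".
    col = str(("posterior", "x[0, 1]"))
    assert name_by_split(col) == "x[0"
    assert name_by_regex(col) == "x[0, 1]"


def test_three_element_label():
    # Hypothetical three-element label.
    col = str(("posterior", "theta[0]", "a"))
    assert name_by_split(col) == "theta[0]"
    assert name_by_regex(col) == "theta[0]', 'a"
```

Running these under pytest against the real labels from the pyrenew-hew object would show which, if either, shape actually occurs there.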

AFg6K7h4fhy2 (Collaborator Author), Mar 19, 2025

The split approach did not work on the InferenceData object from pyrenew-hew in tests when the aggregate function was set to None but did work when the aggregate function was set to first. Yes, both the split and re approaches work on "('posterior', 'alpha')".

Collaborator Author

Thank you for the comment; I will write a test.

Collaborator Author

I should probably investigate why, exactly, aggregate_function=None with the split approach doesn't work for the pyrenew-hew InferenceData object.

assert set(result.keys()) == set(simple_inference_data.groups())


def test_tidydraws_format(simple_inference_data):

Collaborator

This is nice, but I think it's also important to check that array-valued parameters are tidied correctly. The example idata and associated test from forecasttools-R are nice and should be easily portable to Python:

https://github.com/CDCgov/forecasttools/blob/main/data-raw/ex_inferencedata_dataframe.R
https://github.com/CDCgov/forecasttools/blob/main/tests/testthat/test_inferencedata_dataframe_to_tidydraws.R

Collaborator Author

I will save this task for 2025-03-14.

@@ -54,7 +56,7 @@ def convert_inference_data_to_tidydraws(
idata_df.select("chain", "draw", cs.starts_with(f"('{group}',"))
.rename(
{
- col: col.split(", ")[1].strip("')")
+ col: re.search(r",\s*'?(.+?)'?\)", col).group(1)

Collaborator

Feels like this would be better handled by passing a lambda rename mapping to .rename() rather than via a dictionary comprehension. See https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.rename.html

Collaborator Author

Hadn't thought of this. In the back of my mind are some reservations about lambda mappings that I've seen others raise but haven't verified; they aren't enough to stop me from acting here. Polars' rename seems apt. Thank you for the quick comment.

AFg6K7h4fhy2 (Collaborator Author), Mar 18, 2025

I changed the code to:

.rename(
    lambda col: re.search(r",\s*'?(.+?)'?\)", col).group(1)
    if col.startswith(f"('{group}',")
    else col
)

got

forecasttools/idata_to_tidy.py:66:40: B023 Function definition does not bind loop variable `group`

then changed to:

.rename(
    lambda col, group=group: re.search(
        r",\s*'?(.+?)'?\)", col
    ).group(1)
    if col.startswith(f"('{group}',")
    else col
)

which works w/ tests & linting.

Commenting above for my future self.

Collaborator

  1. Readability concern: at least two different senses of "group" are present here (the InferenceData group being tidied and the regex match group returned by .group(1)).

  2. I'd make it a one-argument lambda and exclude that line from linting. Making "group" an argument that defaults to the value of group satisfies the linter but hurts readability, imo.
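
For context on the warning: B023 (flake8-bugbear's function-uses-loop-variable rule, also implemented in ruff) flags a function that reads a loop variable, because Python closures look the variable up when the function is called, not when it is defined. That only causes a bug if the function is called after the loop has moved on. Assuming, as with an eager polars DataFrame, that .rename applies the callable before the comprehension moves to the next group, the one-argument lambda is safe and a line-level # noqa: B023 only silences a false positive. A stdlib-only sketch of both situations; apply_now is a hypothetical stand-in for the eager call:

```python
groups = ["posterior", "posterior_predictive"]


def apply_now(func, value):
    """Hypothetical stand-in for an eager call such as DataFrame.rename(func).

    It applies func straight away, inside the iteration that defined it.
    """
    return func(value)


# The lambda is called immediately, so it sees the current value of group.
immediate = {
    group: apply_now(
        lambda col: col.startswith(f"('{group}',"),  # noqa: B023
        f"('{group}', 'alpha')",
    )
    for group in groups
}

# The lambdas are stored and only called after the comprehension finishes,
# so every one of them sees the last value of group (late binding).
stored = {
    group: (lambda col: col.startswith(f"('{group}',"))  # noqa: B023
    for group in groups
}
deferred = {g: func(f"('{g}', 'alpha')") for g, func in stored.items()}

print(immediate)  # {'posterior': True, 'posterior_predictive': True}
print(deferred)   # {'posterior': False, 'posterior_predictive': True}
```

The deferred case is the bug B023 is designed to catch; the immediate case matches how the rename is used here.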

Collaborator Author

Re: #36 (comment)

Will do. I agree that group=group seems extraneous, but my usual response is to defer to the linter.

Labels: feature (A new tool or utility being added.), Medium Priority (A task that is of medium relative priority.)
Projects: None yet
Development: successfully merging this pull request may close the issue "Function to cast InferenceData into tidy_draws format".
3 participants