
@jemmajeffree (Contributor) commented Jul 18, 2025

Updated the `nanvar` algorithm to use an adapted version of the method from the Schubert and Gertz (2018) paper mentioned in #386, following discussion in #422.

Closes #386
Closes #422
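
For readers coming from the linked issues, here is a minimal sketch of the pairwise combination identity (Chan et al.) that the Schubert and Gertz (2018) approach builds on. It is illustrative only, not the PR's implementation (which carries its intermediates in a `MultiArray`), and the (count, mean, M2) layout is one common choice rather than necessarily the PR's exact layout:

```python
# Illustrative only: combine per-chunk (count, mean, M2) intermediates for one group,
# where M2 is the sum of squared deviations from that chunk's mean.
import numpy as np

def combine_var_intermediates(count_a, mean_a, m2_a, count_b, mean_b, m2_b):
    count = count_a + count_b
    delta = mean_b - mean_a
    mean = mean_a + delta * count_b / count
    m2 = m2_a + m2_b + delta**2 * count_a * count_b / count
    return count, mean, m2

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
a, b = x[:3], x[3:]                      # two "chunks" of the same group
ca, ma, m2a = a.size, a.mean(), ((a - a.mean()) ** 2).sum()
cb, mb, m2b = b.size, b.mean(), ((b - b.mean()) ** 2).sum()
count, mean, m2 = combine_var_intermediates(ca, ma, m2a, cb, mb, m2b)
assert np.isclose(m2 / count, x.var())   # matches the single-pass population variance
```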

flox/core.py (outdated)

```diff
 for reduction, fv, kw, dt in zip(funcs, fill_values, kwargss, dtypes):
-    if empty:
+    # UGLY! but this is because the `var` breaks our design assumptions
+    if empty and reduction is not var_chunk:
```
@dcherian (Collaborator) commented Jul 18, 2025

This code path is an "optimization" for chunks that don't contain any valid groups, so `group_idx` is all -1.
We will need to override `full` in `MultiArray`. Look up what the `like` kwarg does here; it dispatches to the appropriate array type.
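
A minimal sketch of that dispatch mechanism, assuming nothing about flox's actual `MultiArray` (the `TinyMultiArray` class below is purely illustrative):

```python
# np.full(..., like=obj) routes through obj.__array_function__ (NEP 35), so a container
# of several arrays can answer with a matching container of filled arrays.
import numpy as np

class TinyMultiArray:
    def __init__(self, arrays):
        self.arrays = tuple(arrays)

    def __array_function__(self, func, types, args, kwargs):
        if func is np.full:
            shape, fill_value = args[0], args[1]
            kwargs.pop("like", None)  # defensive: don't forward like= to the plain np.full
            return TinyMultiArray(np.full(shape, fill_value, **kwargs) for _ in self.arrays)
        return NotImplemented

template = TinyMultiArray([np.zeros(3), np.zeros(3), np.zeros(3, dtype=np.intp)])
filled = np.full(4, np.nan, like=template)  # dispatches to TinyMultiArray.__array_function__
print(type(filled).__name__, len(filled.arrays))  # TinyMultiArray 3
```

Note that the scalar fill value is applied to every member array here, which is exactly the problem with the "count" intermediate described below.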


The next issue will be that `fill_value` is a scalar like `np.nan`, but that doesn't work for all our intermediates (e.g. the "count").

  1. My first thought is that `MultiArray` will need to track a default `fill_value` per array. For var, this can be initialized to `(None, None, 0)`. If `None`, we use the `fill_value` passed in; else the default.
  2. The other way would be to hardcode some behaviour in `_initialize_aggregation` so that `agg.fill_value["intermediate"] = ((fill_value, fill_value, 0),)`, and then `MultiArray` can receive that tuple and do the "right thing".

The other place this will matter is in `reindex_numpy`, which is executed at the combine step. I suspect the second tuple approach is the best.
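
A minimal sketch of that second option, under stated assumptions: the tuple layout and both helper names below are illustrative, not flox's `_initialize_aggregation` internals:

```python
# Hypothetical helpers: hard-code one fill value per intermediate for var, so that a
# MultiArray-aware np.full can give the "count" array a 0 default instead of NaN.
import numpy as np

def var_intermediate_fill_values(fill_value):
    # (sum-like, sum-of-squared-deviations-like, count); the count must default to 0
    return ((fill_value, fill_value, 0),)

def full_like_multiarray(shapes_and_dtypes, fill_values):
    # stand-in for what a MultiArray np.full override could do with that tuple
    return tuple(
        np.full(shape, fv, dtype=dtype)
        for (shape, dtype), fv in zip(shapes_and_dtypes, fill_values[0])
    )

arrays = full_like_multiarray(
    [((4,), np.float64), ((4,), np.float64), ((4,), np.intp)],
    var_intermediate_fill_values(np.nan),
)
print([str(a.dtype) for a in arrays], arrays[-1])  # counts filled with 0, not NaN
```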

This bit is hairy, and ill-defined. Let me know if you want me to work through it.

@jemmajeffree (Contributor Author) commented

I'm partway through implementing something that works here.

  • How do I trigger this code path without brute-force overwriting `if empty:` with `if True:`?
  • When `np.full` is called, `like` is a numpy array, not a `MultiArray`, because it's (I think) the chunk data and bypasses `var_chunk` (it could also be an artefact of the `if True:` override above?). In a pinch, I guess I could add an `elif` that catches `empty and reduction is var_chunk` and coerces that into a `MultiArray`, but that's also ugly, so I'm hoping you might have better ideas.

@jemmajeffree (Contributor Author) commented

Thinking some more, I may have misinterpreted what `fill_value` is used for. When is it needed for intermediates?

@dcherian (Collaborator) commented

This is great progress! Now we reach some much harder parts. I pushed a commit to show where I think the "chunk" function should go and left a few comments. I think the next steps should be to

  1. address those comments;
  2. add a new test to `test_core.py` with your reproducer (though modified to work with pure numpy arrays);
  3. implement `np.full` for `MultiArray`;
  4. dig a bit more into the "fill value" bits. You'll see that `test_groupby_reduce_all` fails in a couple of places to do with `fill_value` and `setitem`. This will take some work to fix, but basically it has to do with adding a "fill value" for groups that have no value up to this point.
  5. There's another confusing failure where the `MultiArray` only has 2 arrays instead of 3. I don't understand how that happens.

@jemmajeffree (Contributor Author) commented

Ah yes, we will need to add that

pytest "tests/test_core.py::test_cohorts_nd_by[flox-cohorts-None-nanvar]" --pdb
will drop you in a debugger prompt that you can use to explore the problem. we'll need to generalize this to nD.

What will happens is that we combine the intermediates and concatenate them over the axes being reduced over. And then reduce over those axes. At the end, the result will have size =1 along those axes; which are then squeezed out (see _squeeze_intermediates).

Just to make sure I understand: there are situations in which `var_combine` needs to be able to handle combining intermediates along multiple axes at once? If so, that's tricky, because the equation I'm using merges `MultiArray` sets of intermediates pairwise. I think the best way to handle it is probably to stack those axes? (Or is that bad for memory?) Or possibly deal with the dimensions one after the other? Or we could do something awful with a 2D cumulative sum that loops around for calculating the adjustment terms, but that'd pretty much guarantee the code is unintelligible to anybody else.

@dcherian (Collaborator) commented

> Unfortunately I think we need to turn into float64 before calculating intermediates,

We could explicitly accumulate to float64 intermediates in `_var_chunk`; `chunk_reduce` calls `astype(dtype)` on the result anyway. Also note that the cumsum bits in `_var_combine` are the likeliest source of problems here, so those should be accumulating to a 64-bit type.
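
A quick illustration of why the accumulator dtype matters (illustrative only; the array and sizes are arbitrary):

```python
# Sequentially accumulating a long float32 array: the running sum loses precision once it
# is much larger than the increments, unless the accumulator is upcast to float64.
import numpy as np

x = np.full(10_000_000, 0.1, dtype=np.float32)
naive = np.cumsum(x)[-1]                     # accumulates in float32
stable = np.cumsum(x, dtype=np.float64)[-1]  # accumulator upcast to float64
print(naive, stable)                         # the float32 result drifts noticeably from ~1e6
```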

> I think the best way to handle it is probably to stack those axes? (or is that bad for memory?)

Yes, I think a reshape would be fine; the axes you are reshaping will be contiguous, so this will end up being a view of the original and use no extra memory. Just make sure to reshape back to the same dimensionality at the end.
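
A rough sketch of that reshape round-trip (not the PR's code; `sum` stands in for the single-axis combine):

```python
import numpy as np

x = np.random.default_rng(0).normal(size=(2, 3, 4, 5))    # reducing over axes 1 and 2
collapsed = x.reshape(x.shape[0], -1, x.shape[-1])         # contiguous axes merged into one
assert np.shares_memory(x, collapsed)                      # a view, so no extra memory
combined = collapsed.sum(axis=1, keepdims=True)            # stand-in for the pairwise combine
result = combined.reshape(x.shape[0], 1, 1, x.shape[-1])   # back to the original dimensionality
print(result.shape)                                        # (2, 1, 1, 5)
```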

@dcherian (Collaborator) commented

Jemma, please let me know how I can help here. I'm happy to tackle some of these failing tests if you prefer.

@jemmajeffree (Contributor Author) commented

Thanks, Deepak. I'm pretty busy over the next few weeks, which is just to say that if I drop off for a couple of days it's not because I've lost interest. I'm back to normal on 23rd September, but I'll try to get this across the line this week or early next.

I've got a sense of how to implement the reshape/multiple axes for `var_chunk`, and am hoping to get to it soon. I think this will fix most of the failing tests?

Regarding casting to float64, I'm not confident that I've thought through all the edge cases, e.g. we probably wouldn't want to cast `np.longdouble` back to `np.float64`? Or is it a safe assumption that `np.float64` is a good idea for intermediates no matter what the inputs were? If you have a solution that you're happy with, feel free to fix this one; otherwise I'll get to it when I can.

It's probably pretty obvious, but I've got basically no familiarity with pytest so I'm developing that familiarity while trying to use it here. I'm happy to keep working my way slowly through it, but I'm also happy for you to take on other failing tests.

What do you see as the outstanding tasks to get this PR finalised? Is it just to address the causes of all the failing tests? I think we might have had a couple of unresolved comment threads from code reviews that I'll try to track down. Any suggestions for how you'd normally keep track of the last few things to do?

@dcherian (Collaborator) commented

> but I've got basically no familiarity with pytest so I'm developing that familiarity while trying to use it here. I'm happy to keep working my way slowly through it, but I'm also happy for you to take on other failing tests.

pytest is an acquired taste hehe, and it's used in complex ways here. Some tricks that will help:

  1. `pip install pdbpp` for pdb++, a better debugger.
  2. `pytest --pdb` will drop you in a debugger prompt at a failure. This will let you inspect the state and run Python statements. Usually I move up (type `u` for short) until I see code I am familiar with, and then move "down" with `d` to understand what's happening.
  3. To insert a debugger breakpoint, use `import pdb; pdb.set_trace()` anywhere in the code. Run with `pytest -s` for it to wait for input at this prompt. So my usual trick is to insert that breakpoint and then run `pytest -s --pdb TEST_NAME` (see the sketch below).
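
For example (the placement is hypothetical; put it at whatever spot in the code you want to inspect):

```python
# Insert temporarily, e.g. inside the combine step, then run the failing test so
# execution pauses there:
#   pytest -s --pdb "tests/test_core.py::test_cohorts_nd_by[flox-cohorts-None-nanvar]"
import pdb; pdb.set_trace()  # u/d to move between frames, pp <name> to print locals
```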

> cast `np.longdouble` back to `np.float64`

No one's really using that, or if they are, we can fix it when they complain.

> Is it just to address the causes of all the failing tests?

Yes, happy to merge when tests pass.

Looks like you're just down to the property tests failing?! If so, I can clear those up.

@jemmajeffree (Contributor Author) commented

> Looks like you're just down to the property tests failing?!

I think so? I wrapped the part of my code that combines along a single axis in a for loop, so it only has to handle one axis at a time. All the other ideas I tried were nasty to implement.
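
Something along these lines (a rough illustration, not the PR's code; `sum` stands in for the single-axis pairwise combine):

```python
import numpy as np

def combine_one_axis(arr, axis):
    # stand-in for the pairwise combine along a single axis
    return arr.sum(axis=axis, keepdims=True)

x = np.arange(24.0).reshape(2, 3, 4)
out = x
for ax in (1, 2):   # the axes being reduced over; keepdims keeps later axis indices valid
    out = combine_one_axis(out, ax)
print(out.shape)    # (2, 1, 1)
```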

> If so, I can clear those up.

Please :)

@jemmajeffree (Contributor Author) left a comment

Okay, fair.
I was deliberately trying to avoid triggering the "invalid value encountered in divide" warning from `np.divide`, though I concede it was a bit of a hack.
Do we want to force `den == 0` also to be NaN (as opposed to inf)?
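
For context, a small illustration (not the PR's code) of the behaviour under discussion: suppressing the warning and then deciding what `den == 0` becomes:

```python
import numpy as np

num = np.array([1.0, 0.0, 2.0])
den = np.array([2.0, 0.0, 0.0])

with np.errstate(invalid="ignore", divide="ignore"):  # silence the 0/0 and x/0 warnings
    out = num / den                                   # 0/0 -> nan, 2/0 -> inf
out = np.where(den == 0, np.nan, out)                 # force den == 0 to NaN rather than inf
print(out)                                            # [0.5 nan nan]
```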

@dcherian dcherian enabled auto-merge (squash) September 17, 2025 03:29
@dcherian dcherian merged commit 9a765ef into xarray-contrib:main Sep 18, 2025
17 checks passed
@dcherian (Collaborator) commented

so happy this is in! 👏🏾 👏🏾 👏🏾
